[https://issues.apache.org/jira/browse/KAFKA-7285?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16581760#comment-16581760]

ASF GitHub Bot commented on KAFKA-7285:
---------------------------------------

mjsax closed pull request #5501: KAFKA-7285: Create new producer on each 
rebalance if EOS enabled
URL: https://github.com/apache/kafka/pull/5501
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the
original diff of a foreign (forked) pull request once it is merged, the
diff is reproduced below for the sake of provenance:

diff --git 
a/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java 
b/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java
index dc00b473027..538e59c68e1 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java
@@ -67,6 +67,7 @@
     private boolean transactionCommitted;
     private boolean transactionAborted;
     private boolean producerFenced;
+    private boolean producerFencedOnClose;
     private boolean sentOffsets;
     private long commitCount = 0L;
     private Map<MetricName, Metric> mockMetrics;
@@ -311,6 +312,9 @@ public void close() {
 
     @Override
     public void close(long timeout, TimeUnit timeUnit) {
+        if (producerFencedOnClose) {
+            throw new ProducerFencedException("MockProducer is fenced.");
+        }
         this.closed = true;
     }
 
@@ -324,6 +328,12 @@ public void fenceProducer() {
         this.producerFenced = true;
     }
 
+    public void fenceProducerOnClose() {
+        verifyProducerState();
+        verifyTransactionsInitialized();
+        this.producerFencedOnClose = true;
+    }
+
     public boolean transactionInitialized() {
         return this.transactionInitialized;
     }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java
index bf10da2b5e7..09de11d5245 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java
@@ -44,6 +44,12 @@
                      final Serializer<V> valueSerializer,
                      final StreamPartitioner<? super K, ? super V> 
partitioner);
 
+    /**
+     * Initialize the collector with a producer.
+     * @param producer the producer that should be used by this collector
+     */
+    void init(final Producer<byte[], byte[]> producer);
+
     /**
      * Flush the internal {@link Producer}.
      */
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
index d753648eede..e48b4d1825f 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
@@ -51,7 +51,7 @@
     private final Logger log;
     private final String logPrefix;
     private final Sensor skippedRecordsSensor;
-    private final Producer<byte[], byte[]> producer;
+    private Producer<byte[], byte[]> producer;
     private final Map<TopicPartition, Long> offsets;
     private final ProductionExceptionHandler productionExceptionHandler;
 
@@ -61,12 +61,10 @@
     private final static String PARAMETER_HINT = "\nYou can increase producer 
parameter `retries` and `retry.backoff.ms` to avoid this error.";
     private volatile KafkaException sendException;
 
-    public RecordCollectorImpl(final Producer<byte[], byte[]> producer,
-                               final String streamTaskId,
+    public RecordCollectorImpl(final String streamTaskId,
                                final LogContext logContext,
                                final ProductionExceptionHandler 
productionExceptionHandler,
                                final Sensor skippedRecordsSensor) {
-        this.producer = producer;
         this.offsets = new HashMap<>();
         this.logPrefix = String.format("task [%s] ", streamTaskId);
         this.log = logContext.logger(getClass());
@@ -74,6 +72,11 @@ public RecordCollectorImpl(final Producer<byte[], byte[]> 
producer,
         this.skippedRecordsSensor = skippedRecordsSensor;
     }
 
+    @Override
+    public void init(final Producer<byte[], byte[]> producer) {
+        this.producer = producer;
+    }
+
     @Override
     public <K, V> void send(final String topic,
                             final K key,
@@ -239,6 +242,7 @@ public void flush() {
     public void close() {
         log.debug("Closing producer");
         producer.close();
+        producer = null;
         checkForException();
     }
 
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
index 18fe7043b40..67834d7a7cd 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyContextImpl.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.serialization.Serializer;
@@ -58,6 +59,9 @@
                                 final Serializer<V> valueSerializer,
                                 final StreamPartitioner<? super K, ? super V> 
partitioner) {}
 
+        @Override
+        public void init(final Producer<byte[], byte[]> producer) {}
+
         @Override
         public void flush() {}
 
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index 7835a544137..79df5d158a1 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -68,7 +68,8 @@
 
     private final Map<TopicPartition, Long> consumedOffsets;
     private final RecordCollector recordCollector;
-    private final Producer<byte[], byte[]> producer;
+    private final ProducerSupplier producerSupplier;
+    private Producer<byte[], byte[]> producer;
     private final int maxBufferedSize;
 
     private boolean commitRequested = false;
@@ -149,6 +150,10 @@ void removeAllSensors() {
         }
     }
 
+    public interface ProducerSupplier {
+        Producer<byte[], byte[]> get();
+    }
+
     public StreamTask(final TaskId id,
                       final Collection<TopicPartition> partitions,
                       final ProcessorTopology topology,
@@ -159,9 +164,9 @@ public StreamTask(final TaskId id,
                       final StateDirectory stateDirectory,
                       final ThreadCache cache,
                       final Time time,
-                      final Producer<byte[], byte[]> producer,
+                      final ProducerSupplier producerSupplier,
                       final Sensor closeSensor) {
-        this(id, partitions, topology, consumer, changelogReader, config, 
metrics, stateDirectory, cache, time, producer, null, closeSensor);
+        this(id, partitions, topology, consumer, changelogReader, config, 
metrics, stateDirectory, cache, time, producerSupplier, null, closeSensor);
     }
 
     public StreamTask(final TaskId id,
@@ -174,13 +179,14 @@ public StreamTask(final TaskId id,
                       final StateDirectory stateDirectory,
                       final ThreadCache cache,
                       final Time time,
-                      final Producer<byte[], byte[]> producer,
+                      final ProducerSupplier producerSupplier,
                       final RecordCollector recordCollector,
                       final Sensor closeSensor) {
         super(id, partitions, topology, consumer, changelogReader, false, 
stateDirectory, config);
 
         this.time = time;
-        this.producer = producer;
+        this.producerSupplier = producerSupplier;
+        this.producer = producerSupplier.get();
         this.closeSensor = closeSensor;
         this.taskMetrics = new TaskMetrics(id, metrics);
 
@@ -188,7 +194,6 @@ public StreamTask(final TaskId id,
 
         if (recordCollector == null) {
             this.recordCollector = new RecordCollectorImpl(
-                producer,
                 id.toString(),
                 logContext,
                 productionExceptionHandler,
@@ -197,6 +202,8 @@ public StreamTask(final TaskId id,
         } else {
             this.recordCollector = recordCollector;
         }
+        this.recordCollector.init(this.producer);
+
         streamTimePunctuationQueue = new PunctuationQueue();
         systemTimePunctuationQueue = new PunctuationQueue();
         maxBufferedSize = 
config.getInt(StreamsConfig.BUFFERED_RECORDS_PER_PARTITION_CONFIG);
@@ -279,8 +286,15 @@ public void initializeTopology() {
      */
     @Override
     public void resume() {
-        // nothing to do; new transaction will be started only after topology 
is initialized
         log.debug("Resuming");
+        if (eosEnabled) {
+            if (producer != null) {
+                throw new IllegalStateException("Task producer should be 
null.");
+            }
+            producer = producerSupplier.get();
+            producer.initTransactions();
+            recordCollector.init(producer);
+        }
     }
 
     /**
@@ -525,7 +539,7 @@ private void initTopology() {
     @Override
     public void suspend() {
         log.debug("Suspending");
-        suspend(true);
+        suspend(true, false);
     }
 
     /**
@@ -541,10 +555,64 @@ public void suspend() {
      *                               or if the task producer got fenced (EOS)
      */
     // visible for testing
-    void suspend(final boolean clean) {
-        closeTopology(); // should we call this only on clean suspend?
+    void suspend(final boolean clean,
+                 final boolean isZombie) {
+        try {
+            closeTopology(); // should we call this only on clean suspend?
+        } catch (final RuntimeException fatal) {
+            if (clean) {
+                throw fatal;
+            }
+        }
+
         if (clean) {
-            commit(false);
+            TaskMigratedException taskMigratedException = null;
+            try {
+                commit(false);
+            } finally {
+                if (eosEnabled) {
+                    try {
+                        recordCollector.close();
+                    } catch (final ProducerFencedException e) {
+                        taskMigratedException = new 
TaskMigratedException(this, e);
+                    } finally {
+                        producer = null;
+                    }
+                }
+            }
+            if (taskMigratedException != null) {
+                throw taskMigratedException;
+            }
+        } else {
+            maybeAbortTransactionAndCloseRecordCollector(isZombie);
+        }
+    }
+
+    private void maybeAbortTransactionAndCloseRecordCollector(final boolean 
isZombie) {
+        if (eosEnabled && !isZombie) {
+            try {
+                if (transactionInFlight) {
+                    producer.abortTransaction();
+                }
+                transactionInFlight = false;
+            } catch (final ProducerFencedException ignore) {
+                /* TODO
+                 * this should actually never happen atm as we guard the call 
to #abortTransaction
+                 * -> the reason for the guard is a "bug" in the Producer -- 
it throws IllegalStateException
+                 * instead of ProducerFencedException atm. We can remove the 
isZombie flag after KAFKA-5604 got
+                 * fixed and fall-back to this catch-and-swallow code
+                 */
+
+                // can be ignored: transaction got already aborted by 
brokers/transactional-coordinator if this happens
+            }
+
+            try {
+                recordCollector.close();
+            } catch (final Throwable e) {
+                log.error("Failed to close producer due to the following 
error:", e);
+            } finally {
+                producer = null;
+            }
         }
     }
 
@@ -589,37 +657,8 @@ public void closeSuspended(boolean clean,
             log.error("Could not close state manager due to the following 
error:", e);
         }
 
-        try {
-            partitionGroup.close();
-            taskMetrics.removeAllSensors();
-        } finally {
-            if (eosEnabled) {
-                if (!clean) {
-                    try {
-                        if (!isZombie && transactionInFlight) {
-                            producer.abortTransaction();
-                        }
-                        transactionInFlight = false;
-                    } catch (final ProducerFencedException ignore) {
-                        /* TODO
-                         * this should actually never happen atm as we guard 
the call to #abortTransaction
-                         * -> the reason for the guard is a "bug" in the 
Producer -- it throws IllegalStateException
-                         * instead of ProducerFencedException atm. We can 
remove the isZombie flag after KAFKA-5604 got
-                         * fixed and fall-back to this catch-and-swallow code
-                         */
-
-                        // can be ignored: transaction got already aborted by 
brokers/transactional-coordinator if this happens
-                    }
-                }
-                try {
-                    if (!isZombie) {
-                        recordCollector.close();
-                    }
-                } catch (final Throwable e) {
-                    log.error("Failed to close producer due to the following 
error:", e);
-                }
-            }
-        }
+        partitionGroup.close();
+        taskMetrics.removeAllSensors();
 
         closeSensor.record();
 
@@ -630,7 +669,7 @@ public void closeSuspended(boolean clean,
 
     /**
      * <pre>
-     * - {@link #suspend(boolean) suspend(clean)}
+     * - {@link #suspend(boolean, boolean) suspend(clean)}
      *   - close topology
      *   - if (clean) {@link #commit()}
      *     - flush state and producer
@@ -653,7 +692,7 @@ public void close(boolean clean,
 
         RuntimeException firstException = null;
         try {
-            suspend(clean);
+            suspend(clean, isZombie);
         } catch (final RuntimeException e) {
             clean = false;
             firstException = e;
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 968e5779e18..efd94eaf637 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -436,7 +436,7 @@ StreamTask createTask(final Consumer<byte[], byte[]> 
consumer,
                 stateDirectory,
                 cache,
                 time,
-                createProducer(taskId),
+                () -> createProducer(taskId),
                 streamsMetrics.tasksClosedSensor);
         }
 
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java
index 4e14143cf65..7c44258ff89 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java
@@ -92,7 +92,6 @@ public void testMetrics() {
         final InternalMockProcessorContext context = new 
InternalMockProcessorContext(
             anyStateSerde,
             new RecordCollectorImpl(
-                null,
                 null,
                 new LogContext("processnode-test "),
                 new DefaultProductionExceptionHandler(),
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
index 6954eda529f..4f89a1e756f 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
@@ -82,12 +82,12 @@ public Integer partition(final String topic, final String 
key, final Object valu
     public void testSpecificPartition() {
 
         final RecordCollectorImpl collector = new RecordCollectorImpl(
-            new MockProducer<>(cluster, true, new DefaultPartitioner(), 
byteArraySerializer, byteArraySerializer),
             "RecordCollectorTest-TestSpecificPartition",
             new LogContext("RecordCollectorTest-TestSpecificPartition "),
             new DefaultProductionExceptionHandler(),
             new Metrics().sensor("skipped-records")
         );
+        collector.init(new MockProducer<>(cluster, true, new 
DefaultPartitioner(), byteArraySerializer, byteArraySerializer));
 
         final Headers headers = new RecordHeaders(new Header[]{new 
RecordHeader("key", "value".getBytes())});
 
@@ -120,12 +120,12 @@ public void testSpecificPartition() {
     public void testStreamPartitioner() {
 
         final RecordCollectorImpl collector = new RecordCollectorImpl(
-            new MockProducer<>(cluster, true, new DefaultPartitioner(), 
byteArraySerializer, byteArraySerializer),
             "RecordCollectorTest-TestStreamPartitioner",
             new LogContext("RecordCollectorTest-TestStreamPartitioner "),
             new DefaultProductionExceptionHandler(),
             new Metrics().sensor("skipped-records")
         );
+        collector.init(new MockProducer<>(cluster, true, new 
DefaultPartitioner(), byteArraySerializer, byteArraySerializer));
 
         final Headers headers = new RecordHeaders(new Header[]{new 
RecordHeader("key", "value".getBytes())});
 
@@ -152,16 +152,16 @@ public void testStreamPartitioner() {
     @Test(expected = StreamsException.class)
     public void 
shouldThrowStreamsExceptionOnAnyExceptionButProducerFencedException() {
         final RecordCollector collector = new RecordCollectorImpl(
-            new MockProducer(cluster, true, new DefaultPartitioner(), 
byteArraySerializer, byteArraySerializer) {
-                @Override
-                public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
-                    throw new KafkaException();
-                }
-            },
             "test",
             logContext,
             new DefaultProductionExceptionHandler(),
             new Metrics().sensor("skipped-records"));
+        collector.init(new MockProducer(cluster, true, new 
DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
+            @Override
+            public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
+                throw new KafkaException();
+            }
+        });
 
         collector.send("topic1", "3", "0", null, null, stringSerializer, 
stringSerializer, streamPartitioner);
     }
@@ -170,17 +170,17 @@ public void 
shouldThrowStreamsExceptionOnAnyExceptionButProducerFencedException(
     @Test
     public void 
shouldThrowStreamsExceptionOnSubsequentCallIfASendFailsWithDefaultExceptionHandler()
 {
         final RecordCollector collector = new RecordCollectorImpl(
-            new MockProducer(cluster, true, new DefaultPartitioner(), 
byteArraySerializer, byteArraySerializer) {
-                @Override
-                public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
-                    callback.onCompletion(null, new Exception());
-                    return null;
-                }
-            },
             "test",
             logContext,
             new DefaultProductionExceptionHandler(),
             new Metrics().sensor("skipped-records"));
+        collector.init(new MockProducer(cluster, true, new 
DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
+            @Override
+            public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
+                callback.onCompletion(null, new Exception());
+                return null;
+            }
+        });
 
         collector.send("topic1", "3", "0", null, null, stringSerializer, 
stringSerializer, streamPartitioner);
 
@@ -194,17 +194,17 @@ public void 
shouldThrowStreamsExceptionOnSubsequentCallIfASendFailsWithDefaultEx
     @Test
     public void 
shouldNotThrowStreamsExceptionOnSubsequentCallIfASendFailsWithContinueExceptionHandler()
 {
         final RecordCollector collector = new RecordCollectorImpl(
-            new MockProducer(cluster, true, new DefaultPartitioner(), 
byteArraySerializer, byteArraySerializer) {
-                @Override
-                public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
-                    callback.onCompletion(null, new Exception());
-                    return null;
-                }
-            },
             "test",
             logContext,
             new AlwaysContinueProductionExceptionHandler(),
             new Metrics().sensor("skipped-records"));
+        collector.init(new MockProducer(cluster, true, new 
DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
+            @Override
+            public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
+                callback.onCompletion(null, new Exception());
+                return null;
+            }
+        });
 
         collector.send("topic1", "3", "0", null, null, stringSerializer, 
stringSerializer, streamPartitioner);
 
@@ -220,17 +220,17 @@ public void 
shouldRecordSkippedMetricAndLogWarningIfSendFailsWithContinueExcepti
         final MetricName metricName = new MetricName("name", "group", 
"description", Collections.EMPTY_MAP);
         sensor.add(metricName, new Sum());
         final RecordCollector collector = new RecordCollectorImpl(
-            new MockProducer(cluster, true, new DefaultPartitioner(), 
byteArraySerializer, byteArraySerializer) {
-                @Override
-                public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
-                    callback.onCompletion(null, new Exception());
-                    return null;
-                }
-            },
             "test",
             logContext,
             new AlwaysContinueProductionExceptionHandler(),
             sensor);
+        collector.init(new MockProducer(cluster, true, new 
DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
+            @Override
+            public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
+                callback.onCompletion(null, new Exception());
+                return null;
+            }
+        });
         collector.send("topic1", "3", "0", null, null, stringSerializer, 
stringSerializer, streamPartitioner);
         assertEquals(1.0, metrics.metrics().get(metricName).metricValue());
         assertTrue(logCaptureAppender.getMessages().contains("test Error 
sending records (key=[3] value=[0] timestamp=[null]) to topic=[topic1] and 
partition=[0]; The exception handler chose to CONTINUE processing in spite of 
this error."));
@@ -241,17 +241,17 @@ public void 
shouldRecordSkippedMetricAndLogWarningIfSendFailsWithContinueExcepti
     @Test
     public void 
shouldThrowStreamsExceptionOnFlushIfASendFailedWithDefaultExceptionHandler() {
         final RecordCollector collector = new RecordCollectorImpl(
-            new MockProducer(cluster, true, new DefaultPartitioner(), 
byteArraySerializer, byteArraySerializer) {
-                @Override
-                public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
-                    callback.onCompletion(null, new Exception());
-                    return null;
-                }
-            },
             "test",
             logContext,
             new DefaultProductionExceptionHandler(),
             new Metrics().sensor("skipped-records"));
+        collector.init(new MockProducer(cluster, true, new 
DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
+            @Override
+            public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
+                callback.onCompletion(null, new Exception());
+                return null;
+            }
+        });
 
         collector.send("topic1", "3", "0", null, null, stringSerializer, 
stringSerializer, streamPartitioner);
 
@@ -265,17 +265,17 @@ public void 
shouldThrowStreamsExceptionOnFlushIfASendFailedWithDefaultExceptionH
     @Test
     public void 
shouldNotThrowStreamsExceptionOnFlushIfASendFailedWithContinueExceptionHandler()
 {
         final RecordCollector collector = new RecordCollectorImpl(
-            new MockProducer(cluster, true, new DefaultPartitioner(), 
byteArraySerializer, byteArraySerializer) {
-                @Override
-                public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
-                    callback.onCompletion(null, new Exception());
-                    return null;
-                }
-            },
             "test",
             logContext,
             new AlwaysContinueProductionExceptionHandler(),
             new Metrics().sensor("skipped-records"));
+        collector.init(new MockProducer(cluster, true, new 
DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
+            @Override
+            public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
+                callback.onCompletion(null, new Exception());
+                return null;
+            }
+        });
 
         collector.send("topic1", "3", "0", null, null, stringSerializer, 
stringSerializer, streamPartitioner);
 
@@ -286,17 +286,17 @@ public void 
shouldNotThrowStreamsExceptionOnFlushIfASendFailedWithContinueExcept
     @Test
     public void 
shouldThrowStreamsExceptionOnCloseIfASendFailedWithDefaultExceptionHandler() {
         final RecordCollector collector = new RecordCollectorImpl(
-            new MockProducer(cluster, true, new DefaultPartitioner(), 
byteArraySerializer, byteArraySerializer) {
-                @Override
-                public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
-                    callback.onCompletion(null, new Exception());
-                    return null;
-                }
-            },
             "test",
             logContext,
             new DefaultProductionExceptionHandler(),
             new Metrics().sensor("skipped-records"));
+        collector.init(new MockProducer(cluster, true, new 
DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
+            @Override
+            public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
+                callback.onCompletion(null, new Exception());
+                return null;
+            }
+        });
 
         collector.send("topic1", "3", "0", null, null, stringSerializer, 
stringSerializer, streamPartitioner);
 
@@ -310,17 +310,17 @@ public void 
shouldThrowStreamsExceptionOnCloseIfASendFailedWithDefaultExceptionH
     @Test
     public void 
shouldNotThrowStreamsExceptionOnCloseIfASendFailedWithContinueExceptionHandler()
 {
         final RecordCollector collector = new RecordCollectorImpl(
-            new MockProducer(cluster, true, new DefaultPartitioner(), 
byteArraySerializer, byteArraySerializer) {
-                @Override
-                public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
-                    callback.onCompletion(null, new Exception());
-                    return null;
-                }
-            },
             "test",
             logContext,
             new AlwaysContinueProductionExceptionHandler(),
             new Metrics().sensor("skipped-records"));
+        collector.init(new MockProducer(cluster, true, new 
DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
+            @Override
+            public synchronized Future<RecordMetadata> send(final 
ProducerRecord record, final Callback callback) {
+                callback.onCompletion(null, new Exception());
+                return null;
+            }
+        });
 
         collector.send("topic1", "3", "0", null, null, stringSerializer, 
stringSerializer, streamPartitioner);
 
@@ -331,17 +331,17 @@ public void 
shouldNotThrowStreamsExceptionOnCloseIfASendFailedWithContinueExcept
     @Test(expected = StreamsException.class)
     public void shouldThrowIfTopicIsUnknownWithDefaultExceptionHandler() {
         final RecordCollector collector = new RecordCollectorImpl(
-            new MockProducer(cluster, true, new DefaultPartitioner(), 
byteArraySerializer, byteArraySerializer) {
-                @Override
-                public List<PartitionInfo> partitionsFor(final String topic) {
-                    return Collections.EMPTY_LIST;
-                }
-
-            },
             "test",
             logContext,
             new DefaultProductionExceptionHandler(),
             new Metrics().sensor("skipped-records"));
+        collector.init(new MockProducer(cluster, true, new 
DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
+            @Override
+            public List<PartitionInfo> partitionsFor(final String topic) {
+                return Collections.EMPTY_LIST;
+            }
+
+        });
         collector.send("topic1", "3", "0", null, null, stringSerializer, 
stringSerializer, streamPartitioner);
     }
 
@@ -349,17 +349,17 @@ public void 
shouldThrowIfTopicIsUnknownWithDefaultExceptionHandler() {
     @Test(expected = StreamsException.class)
     public void shouldThrowIfTopicIsUnknownWithContinueExceptionHandler() {
         final RecordCollector collector = new RecordCollectorImpl(
-            new MockProducer(cluster, true, new DefaultPartitioner(), 
byteArraySerializer, byteArraySerializer) {
-                @Override
-                public List<PartitionInfo> partitionsFor(final String topic) {
-                    return Collections.EMPTY_LIST;
-                }
-
-            },
             "test",
             logContext,
             new AlwaysContinueProductionExceptionHandler(),
             new Metrics().sensor("skipped-records"));
+        collector.init(new MockProducer(cluster, true, new 
DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
+            @Override
+            public List<PartitionInfo> partitionsFor(final String topic) {
+                return Collections.EMPTY_LIST;
+            }
+
+        });
         collector.send("topic1", "3", "0", null, null, stringSerializer, 
stringSerializer, streamPartitioner);
     }
 }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
index cf1d63f1e1e..a8cd2c8f357 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
@@ -61,7 +61,6 @@
     final InternalMockProcessorContext context = new 
InternalMockProcessorContext(
         StateSerdes.withBuiltinTypes("anyName", Bytes.class, Bytes.class),
         new RecordCollectorImpl(
-            null,
             null,
             new LogContext("record-queue-test "),
             new DefaultProductionExceptionHandler(),
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/SinkNodeTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/SinkNodeTest.java
index dacc17e86e7..269983f6380 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/SinkNodeTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/SinkNodeTest.java
@@ -37,20 +37,22 @@
 public class SinkNodeTest {
     private final Serializer<byte[]> anySerializer = 
Serdes.ByteArray().serializer();
     private final StateSerdes<Bytes, Bytes> anyStateSerde = 
StateSerdes.withBuiltinTypes("anyName", Bytes.class, Bytes.class);
+    private final RecordCollector recordCollector =  new RecordCollectorImpl(
+        null,
+        new LogContext("sinknode-test "),
+        new DefaultProductionExceptionHandler(),
+        new Metrics().sensor("skipped-records")
+    );
+
     private final InternalMockProcessorContext context = new 
InternalMockProcessorContext(
         anyStateSerde,
-        new RecordCollectorImpl(
-            new MockProducer<>(true, anySerializer, anySerializer),
-            null,
-            new LogContext("sinknode-test "),
-            new DefaultProductionExceptionHandler(),
-            new Metrics().sensor("skipped-records")
-        )
+        recordCollector
     );
     private final SinkNode sink = new SinkNode<>("anyNodeName", new 
StaticTopicNameExtractor("any-output-topic"), anySerializer, anySerializer, 
null);
 
     @Before
     public void before() {
+        recordCollector.init(new MockProducer<>(true, anySerializer, 
anySerializer));
         sink.init(context);
     }
 
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index 8f25c5304ce..39d654832bf 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -22,6 +22,7 @@
 import org.apache.kafka.clients.producer.MockProducer;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.ProducerFencedException;
 import org.apache.kafka.common.metrics.KafkaMetric;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.metrics.Sensor;
@@ -37,6 +38,7 @@
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.DefaultProductionExceptionHandler;
 import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.processor.PunctuationType;
 import org.apache.kafka.streams.processor.Punctuator;
 import org.apache.kafka.streams.processor.StateRestoreListener;
@@ -116,7 +118,7 @@ public void close() {
     );
 
     private final MockConsumer<byte[], byte[]> consumer = new 
MockConsumer<>(OffsetResetStrategy.EARLIEST);
-    private final MockProducer<byte[], byte[]> producer = new 
MockProducer<>(false, bytesSerializer, bytesSerializer);
+    private MockProducer<byte[], byte[]> producer;
     private final MockConsumer<byte[], byte[]> restoreStateConsumer = new 
MockConsumer<>(OffsetResetStrategy.EARLIEST);
     private final StateRestoreListener stateRestoreListener = new 
MockStateRestoreListener();
     private final StoreChangelogReader changelogReader = new 
StoreChangelogReader(restoreStateConsumer, Duration.ZERO, stateRestoreListener, 
new LogContext("stream-task-test ")) {
@@ -551,7 +553,7 @@ public void shouldPunctuateOnceSystemTimeAfterGap() {
 
     @Test
     public void shouldWrapKafkaExceptionsWithStreamsExceptionAndAddContext() {
-        task = createTaskThatThrowsException();
+        task = createTaskThatThrowsException(false);
         task.initializeStateStores();
         task.initializeTopology();
         task.addRecords(partition2, 
singletonList(getConsumerRecord(partition2, 0)));
@@ -621,7 +623,7 @@ public void shouldFlushRecordCollectorOnFlushState() {
             stateDirectory,
             null,
             time,
-            producer,
+            () -> producer = new MockProducer<>(false, bytesSerializer, 
bytesSerializer),
             new NoOpRecordCollector() {
                 @Override
                 public void flush() {
@@ -717,15 +719,178 @@ public void punctuate(final long timestamp) {
         });
     }
 
+    @Test
+    public void shouldNotCloseProducerOnCleanCloseWithEosDisabled() {
+        task = createStatelessTask(createConfig(false));
+        task.close(true, false);
+        task = null;
+
+        assertFalse(producer.closed());
+    }
+
+    @Test
+    public void shouldNotCloseProducerOnUncleanCloseWithEosDisabled() {
+        task = createStatelessTask(createConfig(false));
+        task.close(false, false);
+        task = null;
+
+        assertFalse(producer.closed());
+    }
+
+    @Test
+    public void shouldNotCloseProducerOnErrorDuringCleanCloseWithEosDisabled() 
{
+        task = createTaskThatThrowsException(false);
+
+        try {
+            task.close(true, false);
+            fail("should have thrown runtime exception");
+        } catch (final RuntimeException expected) {
+            task = null;
+        }
+
+        assertFalse(producer.closed());
+    }
+
+    @Test
+    public void 
shouldNotCloseProducerOnErrorDuringUncleanCloseWithEosDisabled() {
+        task = createTaskThatThrowsException(false);
+
+        task.close(false, false);
+        task = null;
+
+        assertFalse(producer.closed());
+    }
+
+    @Test
+    public void 
shouldCommitTransactionAndCloseProducerOnCleanCloseWithEosEnabled() {
+        task = createStatelessTask(createConfig(true));
+        task.initializeTopology();
+
+        task.close(true, false);
+        task = null;
+
+        assertTrue(producer.transactionCommitted());
+        assertFalse(producer.transactionInFlight());
+        assertTrue(producer.closed());
+    }
+
+    @Test
+    public void 
shouldNotAbortTransactionAndNotCloseProducerOnErrorDuringCleanCloseWithEosEnabled()
 {
+        task = createTaskThatThrowsException(true);
+        task.initializeTopology();
+
+        try {
+            task.close(true, false);
+            fail("should have thrown runtime exception");
+        } catch (final RuntimeException expected) {
+            task = null;
+        }
+
+        assertTrue(producer.transactionInFlight());
+        assertFalse(producer.closed());
+    }
+
+    @Test
+    public void 
shouldOnlyCloseProducerIfFencedOnCommitDuringCleanCloseWithEosEnabled() {
+        task = createStatelessTask(createConfig(true));
+        task.initializeTopology();
+        producer.fenceProducer();
+
+        try {
+            task.close(true, false);
+            fail("should have thrown TaskMigratedException");
+        } catch (final TaskMigratedException expected) {
+            task = null;
+            assertTrue(expected.getCause() instanceof ProducerFencedException);
+        }
+
+        assertFalse(producer.transactionCommitted());
+        assertTrue(producer.transactionInFlight());
+        assertFalse(producer.transactionAborted());
+        assertFalse(producer.transactionCommitted());
+        assertTrue(producer.closed());
+    }
+
+    @Test
+    public void 
shouldNotCloseProducerIfFencedOnCloseDuringCleanCloseWithEosEnabled() {
+        task = createStatelessTask(createConfig(true));
+        task.initializeTopology();
+        producer.fenceProducerOnClose();
+
+        try {
+            task.close(true, false);
+            fail("should have thrown TaskMigratedException");
+        } catch (final TaskMigratedException expected) {
+            task = null;
+            assertTrue(expected.getCause() instanceof ProducerFencedException);
+        }
+
+        assertTrue(producer.transactionCommitted());
+        assertFalse(producer.transactionInFlight());
+        assertFalse(producer.closed());
+    }
+
+    @Test
+    public void 
shouldAbortTransactionAndCloseProducerOnUncleanCloseWithEosEnabled() {
+        task = createStatelessTask(createConfig(true));
+        task.initializeTopology();
+
+        task.close(false, false);
+        task = null;
+
+        assertTrue(producer.transactionAborted());
+        assertFalse(producer.transactionInFlight());
+        assertTrue(producer.closed());
+    }
+
+    @Test
+    public void 
shouldAbortTransactionAndCloseProducerOnErrorDuringUncleanCloseWithEosEnabled() 
{
+        task = createTaskThatThrowsException(true);
+        task.initializeTopology();
+
+        task.close(false, false);
+
+        assertTrue(producer.transactionAborted());
+        assertTrue(producer.closed());
+    }
+
+    @Test
+    public void 
shouldOnlyCloseProducerIfFencedOnAbortDuringUncleanCloseWithEosEnabled() {
+        task = createStatelessTask(createConfig(true));
+        task.initializeTopology();
+        producer.fenceProducer();
+
+        task.close(false, false);
+        task = null;
+
+        assertTrue(producer.transactionInFlight());
+        assertFalse(producer.transactionAborted());
+        assertFalse(producer.transactionCommitted());
+        assertTrue(producer.closed());
+    }
+
+    @Test
+    public void 
shouldAbortTransactionButNotCloseProducerIfFencedOnCloseDuringUncleanCloseWithEosEnabled()
 {
+        task = createStatelessTask(createConfig(true));
+        task.initializeTopology();
+        producer.fenceProducerOnClose();
+
+        task.close(false, false);
+        task = null;
+
+        assertTrue(producer.transactionAborted());
+        assertFalse(producer.closed());
+    }
+
     @Test
     public void 
shouldThrowExceptionIfAnyExceptionsRaisedDuringCloseButStillCloseAllProcessorNodesTopology()
 {
-        task = createTaskThatThrowsException();
+        task = createTaskThatThrowsException(false);
         task.initializeStateStores();
         task.initializeTopology();
         try {
             task.close(true, false);
             fail("should have thrown runtime exception");
-        } catch (final RuntimeException e) {
+        } catch (final RuntimeException expected) {
             task = null;
         }
         assertTrue(processorSystemTime.closed);
@@ -742,6 +907,19 @@ public void 
shouldInitAndBeginTransactionOnCreateIfEosEnabled() {
         assertTrue(producer.transactionInFlight());
     }
 
+    @Test
+    public void 
shouldWrapProducerFencedExceptionWithTaskMigragedExceptionForBeginTransaction() 
{
+        task = createStatelessTask(createConfig(true));
+        producer.fenceProducer();
+
+        try {
+            task.initializeTopology();
+            fail("Should have throws TaskMigratedException");
+        } catch (final TaskMigratedException expected) {
+            assertTrue(expected.getCause() instanceof ProducerFencedException);
+        }
+    }
+
     @Test
     public void shouldNotThrowOnCloseIfTaskWasNotInitializedWithEosEnabled() {
         task = createStatelessTask(createConfig(true));
@@ -794,6 +972,37 @@ public void 
shouldNotSendOffsetsAndCommitTransactionNorStartNewTransactionOnSusp
         assertFalse(producer.transactionInFlight());
     }
 
+    @Test
+    public void 
shouldWrapProducerFencedExceptionWithTaskMigragedExceptionInSuspendWhenCommitting()
 {
+        task = createStatelessTask(createConfig(true));
+        producer.fenceProducer();
+
+        try {
+            task.suspend();
+            fail("Should have throws TaskMigratedException");
+        } catch (final TaskMigratedException expected) {
+            assertTrue(expected.getCause() instanceof ProducerFencedException);
+        }
+
+        assertFalse(producer.transactionCommitted());
+    }
+
+    @Test
+    public void 
shouldWrapProducerFencedExceptionWithTaskMigragedExceptionInSuspendWhenClosingProducer()
 {
+        task = createStatelessTask(createConfig(true));
+        task.initializeTopology();
+
+        producer.fenceProducerOnClose();
+        try {
+            task.suspend();
+            fail("Should have throws TaskMigratedException");
+        } catch (final TaskMigratedException expected) {
+            assertTrue(expected.getCause() instanceof ProducerFencedException);
+        }
+
+        assertTrue(producer.transactionCommitted());
+    }
+
     @Test
     public void shouldStartNewTransactionOnResumeIfEosEnabled() {
         task = createStatelessTask(createConfig(true));
@@ -843,16 +1052,6 @@ public void 
shouldNotStartNewTransactionOnCommitIfEosDisabled() {
         assertFalse(producer.transactionInFlight());
     }
 
-    @Test
-    public void shouldAbortTransactionOnDirtyClosedIfEosEnabled() {
-        task = createStatelessTask(createConfig(true));
-        task.initializeTopology();
-        task.close(false, false);
-        task = null;
-
-        assertTrue(producer.transactionAborted());
-    }
-
     @Test
     public void shouldNotAbortTransactionOnZombieClosedIfEosEnabled() {
         task = createStatelessTask(createConfig(true));
@@ -883,7 +1082,7 @@ public void shouldCloseProducerOnCloseWhenEosEnabled() {
 
     @Test
     public void shouldNotViolateAtLeastOnceWhenExceptionOccursDuringFlushing() 
{
-        task = createTaskThatThrowsException();
+        task = createTaskThatThrowsException(false);
         task.initializeStateStores();
         task.initializeTopology();
 
@@ -897,7 +1096,7 @@ public void 
shouldNotViolateAtLeastOnceWhenExceptionOccursDuringFlushing() {
 
     @Test
     public void 
shouldNotViolateAtLeastOnceWhenExceptionOccursDuringTaskSuspension() {
-        final StreamTask task = createTaskThatThrowsException();
+        final StreamTask task = createTaskThatThrowsException(false);
 
         task.initializeStateStores();
         task.initializeTopology();
@@ -928,7 +1127,7 @@ public void shouldCloseStateManagerIfFailureOnTaskClose() {
 
     @Test
     public void shouldNotCloseTopologyProcessorNodesIfNotInitialized() {
-        final StreamTask task = createTaskThatThrowsException();
+        final StreamTask task = createTaskThatThrowsException(false);
         try {
             task.close(false, false);
         } catch (final Exception e) {
@@ -972,7 +1171,7 @@ public void 
shouldReturnOffsetsForRepartitionTopicsForPurging() {
             stateDirectory,
             null,
             time,
-            producer,
+            () -> producer = new MockProducer<>(false, bytesSerializer, 
bytesSerializer),
             metrics.sensor("dummy"));
         task.initializeStateStores();
         task.initializeTopology();
@@ -1006,10 +1205,12 @@ public void 
shouldThrowOnCleanCloseTaskWhenEosEnabledIfTransactionInFlight() {
 
     @Test
     public void shouldAlwaysCommitIfEosEnabled() {
-        final RecordCollectorImpl recordCollector =  new 
RecordCollectorImpl(producer, "StreamTask",
+        task = createStatelessTask(createConfig(true));
+
+        final RecordCollectorImpl recordCollector =  new 
RecordCollectorImpl("StreamTask",
                 new LogContext("StreamTaskTest "), new 
DefaultProductionExceptionHandler(), new Metrics().sensor("skipped-records"));
+        recordCollector.init(producer);
 
-        task = createStatelessTask(createConfig(true));
         task.initializeStateStores();
         task.initializeTopology();
         task.punctuate(processorSystemTime, 5, 
PunctuationType.WALL_CLOCK_TIME, new Punctuator() {
@@ -1041,7 +1242,7 @@ private StreamTask createStatefulTask(final StreamsConfig 
config, final boolean
             stateDirectory,
             null,
             time,
-            producer,
+            () -> producer = new MockProducer<>(false, bytesSerializer, 
bytesSerializer),
             metrics.sensor("dummy"));
     }
 
@@ -1063,7 +1264,7 @@ private StreamTask 
createStatefulTaskThatThrowsExceptionOnClose() {
             stateDirectory,
             null,
             time,
-            producer,
+            () -> producer = new MockProducer<>(false, bytesSerializer, 
bytesSerializer),
             metrics.sensor("dummy"));
     }
 
@@ -1089,12 +1290,12 @@ private StreamTask createStatelessTask(final 
StreamsConfig streamsConfig) {
             stateDirectory,
             null,
             time,
-            producer,
+            () -> producer = new MockProducer<>(false, bytesSerializer, 
bytesSerializer),
             metrics.sensor("dummy"));
     }
 
     // this task will throw exception when processing (on partition2), 
flushing, suspending and closing
-    private StreamTask createTaskThatThrowsException() {
+    private StreamTask createTaskThatThrowsException(final boolean enableEos) {
         final ProcessorTopology topology = ProcessorTopology.withSources(
             Utils.<ProcessorNode>mkList(source1, source3, processorStreamTime, 
processorSystemTime),
             mkMap(mkEntry(topic1, (SourceNode) source1), mkEntry(topic2, 
(SourceNode) source3))
@@ -1111,12 +1312,12 @@ private StreamTask createTaskThatThrowsException() {
             topology,
             consumer,
             changelogReader,
-            createConfig(false),
+            createConfig(enableEos),
             streamsMetrics,
             stateDirectory,
             null,
             time,
-            producer,
+            () -> producer = new MockProducer<>(false, bytesSerializer, 
bytesSerializer),
             metrics.sensor("dummy")) {
             @Override
             protected void flushState() {
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 6056b2d1f56..c1485fb8056 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -690,8 +690,7 @@ public boolean conditionMet() {
         assertThat(producer.commitCount(), equalTo(2L));
     }
 
-    @Test
-    public void 
shouldCloseTaskAsZombieAndRemoveFromActiveTasksIfProducerGotFencedAtBeginTransactionWhenTaskIsResumed()
 {
+    private StreamThread setupStreamThread() {
         internalTopologyBuilder.addSource(null, "name", null, null, null, 
topic1);
         internalTopologyBuilder.addSink("out", "output", null, null, null, 
"name");
 
@@ -717,15 +716,32 @@ public void 
shouldCloseTaskAsZombieAndRemoveFromActiveTasksIfProducerGotFencedAt
         thread.runOnce(-1);
 
         assertThat(thread.tasks().size(), equalTo(1));
+        return thread;
+    }
+
+    @Test
+    public void 
shouldCloseTaskAsZombieAndRemoveFromActiveTasksIfProducerGotFencedInCommitTransactionWhenSuspendingTaks()
 {
+        final StreamThread thread = setupStreamThread();
 
-        thread.rebalanceListener.onPartitionsRevoked(null);
         clientSupplier.producers.get(0).fenceProducer();
-        thread.rebalanceListener.onPartitionsAssigned(assignedPartitions);
-        try {
-            thread.runOnce(-1);
-            fail("Should have thrown TaskMigratedException");
-        } catch (final TaskMigratedException expected) { /* ignore */ }
+        thread.rebalanceListener.onPartitionsRevoked(null);
+
+        assertTrue(clientSupplier.producers.get(0).transactionInFlight());
+        assertFalse(clientSupplier.producers.get(0).transactionCommitted());
+        assertTrue(clientSupplier.producers.get(0).closed());
+        assertTrue(thread.tasks().isEmpty());
+    }
+
+    @Test
+    public void 
shouldCloseTaskAsZombieAndRemoveFromActiveTasksIfProducerGotFencedInCloseTransactionWhenSuspendingTaks()
 {
+        final StreamThread thread = setupStreamThread();
+
+        clientSupplier.producers.get(0).fenceProducerOnClose();
+        thread.rebalanceListener.onPartitionsRevoked(null);
 
+        assertFalse(clientSupplier.producers.get(0).transactionInFlight());
+        assertTrue(clientSupplier.producers.get(0).transactionCommitted());
+        assertFalse(clientSupplier.producers.get(0).closed());
         assertTrue(thread.tasks().isEmpty());
     }
 
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
index 21049b1068d..699963395e9 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
@@ -190,7 +190,6 @@ private KeyValueStoreTestDriver(final StateSerdes<K, V> 
serdes) {
         final Producer<byte[], byte[]> producer = new MockProducer<>(true, 
rawSerializer, rawSerializer);
 
         final RecordCollector recordCollector = new RecordCollectorImpl(
-            producer,
             "KeyValueStoreTestDriver",
             new LogContext("KeyValueStoreTestDriver "),
             new DefaultProductionExceptionHandler(),
@@ -225,6 +224,7 @@ private KeyValueStoreTestDriver(final StateSerdes<K, V> 
serdes) {
                 throw new UnsupportedOperationException();
             }
         };
+        recordCollector.init(producer);
 
         final File stateDir = TestUtils.tempDirectory();
         //noinspection ResultOfMethodCallIgnored
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
index b0057e5495f..348d1be697d 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
@@ -41,6 +41,7 @@
 import org.apache.kafka.test.StreamsTestUtils;
 import org.apache.kafka.test.TestUtils;
 import org.junit.After;
+import org.junit.Before;
 import org.junit.Test;
 
 import java.io.File;
@@ -79,7 +80,6 @@
 
     private final Producer<byte[], byte[]> producer = new MockProducer<>(true, 
Serdes.ByteArray().serializer(), Serdes.ByteArray().serializer());
     private final RecordCollector recordCollector = new RecordCollectorImpl(
-        producer,
         "RocksDBWindowStoreTestTask",
         new LogContext("RocksDBWindowStoreTestTask "),
         new DefaultProductionExceptionHandler(),
@@ -115,6 +115,11 @@
         return store;
     }
 
+    @Before
+    public void initRecordCollector() {
+        recordCollector.init(producer);
+    }
+
     @After
     public void closeStore() {
         if (windowStore != null) {
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/StoreChangeLoggerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/StoreChangeLoggerTest.java
index 5afe14f8a0a..7186c28b0fa 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/StoreChangeLoggerTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/StoreChangeLoggerTest.java
@@ -43,7 +43,7 @@
     private final Map<Integer, Headers> loggedHeaders = new HashMap<>();
 
     private final InternalMockProcessorContext context = new 
InternalMockProcessorContext(StateSerdes.withBuiltinTypes(topic, Integer.class, 
String.class),
-        new RecordCollectorImpl(null, "StoreChangeLoggerTest", new 
LogContext("StoreChangeLoggerTest "), new DefaultProductionExceptionHandler(), 
new Metrics().sensor("skipped-records")) {
+        new RecordCollectorImpl("StoreChangeLoggerTest", new 
LogContext("StoreChangeLoggerTest "), new DefaultProductionExceptionHandler(), 
new Metrics().sensor("skipped-records")) {
             @Override
             public <K1, V1> void send(final String topic,
                                       final K1 key,
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
index 75bb21943ad..33dce97f376 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
@@ -124,7 +124,7 @@ public void cleanUp() throws IOException {
     public void shouldFindKeyValueStores() {
         mockThread(true);
         final List<ReadOnlyKeyValueStore<String, String>> kvStores =
-            provider.stores("kv-store", QueryableStoreTypes.<String, 
String>keyValueStore());
+            provider.stores("kv-store", QueryableStoreTypes.keyValueStore());
         assertEquals(2, kvStores.size());
     }
 
@@ -190,7 +190,7 @@ private StreamTask createStreamsTask(final StreamsConfig 
streamsConfig,
             stateDirectory,
             null,
             new MockTime(),
-            clientSupplier.getProducer(new HashMap<String, Object>()),
+            () -> clientSupplier.getProducer(new HashMap<>()),
             metrics.sensor("dummy")) {
             @Override
             protected void updateOffsetLimits() {}
diff --git a/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java 
b/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
index 698cdc7ff85..b83936b8df5 100644
--- a/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
+++ b/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
@@ -235,7 +235,7 @@ private ProcessorRecordContext createRecordContext(final 
String topicName, final
 
     private class MockRecordCollector extends RecordCollectorImpl {
         MockRecordCollector() {
-            super(null, "KStreamTestDriver", new LogContext("KStreamTestDriver 
"), new DefaultProductionExceptionHandler(), new 
Metrics().sensor("skipped-records"));
+            super("KStreamTestDriver", new LogContext("KStreamTestDriver "), 
new DefaultProductionExceptionHandler(), new 
Metrics().sensor("skipped-records"));
         }
 
         @Override
diff --git 
a/streams/src/test/java/org/apache/kafka/test/NoOpRecordCollector.java 
b/streams/src/test/java/org/apache/kafka/test/NoOpRecordCollector.java
index 893d3566c6a..07ba9b4b98c 100644
--- a/streams/src/test/java/org/apache/kafka/test/NoOpRecordCollector.java
+++ b/streams/src/test/java/org/apache/kafka/test/NoOpRecordCollector.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.test;
 
+import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.serialization.Serializer;
@@ -47,6 +48,9 @@
                             final Serializer<V> valueSerializer,
                             final StreamPartitioner<? super K, ? super V> 
partitioner) {}
 
+    @Override
+    public void init(final Producer<byte[], byte[]> producer) {}
+
     @Override
     public void flush() {}
 
diff --git 
a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
 
b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
index d2796db6f9c..12974ae1ebc 100644
--- 
a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
+++ 
b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
@@ -338,7 +338,7 @@ public void onRestoreEnd(final TopicPartition 
topicPartition, final String store
                 stateDirectory,
                 cache,
                 mockWallClockTime,
-                producer,
+                () -> producer,
                 metrics.sensor("dummy"));
             task.initializeStateStores();
             task.initializeTopology();
@@ -680,6 +680,10 @@ public void close() {
         stateDirectory.clean();
     }
 
+    private Producer<byte[], byte[]> get() {
+        return producer;
+    }
+
     static class MockTime implements Time {
         private final AtomicLong timeMs;
         private final AtomicLong highResTimeNs;


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


> Streams should be more fencing-sensitive during task suspension under EOS
> -------------------------------------------------------------------------
>
>                 Key: KAFKA-7285
>                 URL: https://issues.apache.org/jira/browse/KAFKA-7285
>             Project: Kafka
>          Issue Type: Improvement
>          Components: streams
>    Affects Versions: 0.11.0.3, 1.0.2, 1.1.1, 2.0.0
>            Reporter: Guozhang Wang
>            Assignee: Matthias J. Sax
>            Priority: Major
>
> When EOS is turned on, Streams did the following steps:
> 1. InitTxn in task creation.
> 2. BeginTxn in topology initialization.
> 3. AbortTxn in clean shutdown.
> 4. CommitTxn in commit(), which is called in suspend() as well.
> Now consider this situation, with two thread (Ta) and (Tb) and one task:
> 1. originally Ta owns the task, consumer generation is 1.
> 2. Ta is un-responsive to send heartbeats, and gets kicked out, a new 
> generation 2 is formed with Tb in it. The task is migrated to Tb while Ta 
> does not know.
> 3. Ta finally calls `consumer.poll` and was aware of the rebalance, it 
> re-joins the group, forming a new generation of 3. And during the rebalance 
> the leader decides to assign the task back to Ta.
> 4.a) Ta calls onPartitionRevoked on the task, suspending it and calling 
> commit. However, if no data has been sent since `BeginTxn`, this commit 
> call will become a no-op.
> 4.b) Ta then calls onPartitionAssigned on the task, resuming it, and then 
> calls BeginTxn. It then incorrectly encounters a ProducerFencedException.
> The root cause is that Ta does not trigger InitTxn to claim "I'm the newest 
> client for this txnId, and am going to fence everyone else with the same 
> txnId", so it is mistakenly treated as an older client than Tb.
> Note that this issue is not common, since we need to encounter a txn that did 
> not send any data at all — making its commitTxn call a no-op — so that the 
> producer is not fenced earlier on.
> One proposal for this issue is to close the producer and recreate a new one 
> in `suspend` after the commitTxn call succeeds and `startNewTxn` is false, 
> so that the new producer will always `initTxn` to fence others.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to