This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 7c539f90298 KAFKA-16448: Unify error-callback exception handling 
(#16745)
7c539f90298 is described below

commit 7c539f902983a76f14d0cd993e7d6dcbfdacd909
Author: Matthias J. Sax <matth...@confluent.io>
AuthorDate: Sat Aug 3 12:40:51 2024 -0700

    KAFKA-16448: Unify error-callback exception handling (#16745)
    
    Follow-up code cleanup for KIP-1033.
    
    This PR unifies the handling of both error cases for exception handlers:
     - handler throws an exception
     - handler returns null
    
    The unification happens for all 5 handler cases (a condensed sketch of
    the shared pattern follows the list):
     - deserialization
     - production / serialization
     - production / send
     - processing
     - punctuation
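    
    All five call sites now share one shape: the handler invocation is
    wrapped in `Objects.requireNonNull(...)` and guarded by a single
    `catch (RuntimeException)`, so a handler that returns `null` and a
    handler that throws are reported identically. A condensed sketch of
    that shape, adapted from the processing case in ProcessorNode below
    (the other call sites differ in the handler used, the log message,
    and how the wrapped exception is surfaced):
    
        final ProcessingHandlerResponse response;
        try {
            // a `null` response becomes a NullPointerException here, so
            // "handler returned null" lands in the same catch block as
            // "handler threw"
            response = Objects.requireNonNull(
                processingExceptionHandler.handle(
                    errorHandlerContext, record, processingException),
                "Invalid ProcessingExceptionHandler response."
            );
        } catch (final RuntimeException fatalUserException) {
            log.error(
                "Processing error callback failed after processing error"
                    + " for record: {}",
                errorHandlerContext,
                processingException
            );
            throw new FailedProcessingException(
                "Fatal user code error in processing error callback",
                fatalUserException);
        }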
    
Reviewers: Sebastien Viale <sebastien.vi...@michelin.com>, Loic Greffier 
<loic.greff...@michelin.com>, Bill Bejeck <b...@confluent.io>
---
 .../errors/DeserializationExceptionHandler.java    |   2 +-
 .../streams/errors/ProductionExceptionHandler.java |   2 +-
 .../internals/DefaultErrorHandlerContext.java      |  12 ++
 .../internals/FailedProcessingException.java       |   7 +-
 .../internals/GlobalStateManagerImpl.java          |   5 +-
 .../streams/processor/internals/ProcessorNode.java |  21 ++-
 .../processor/internals/RecordCollectorImpl.java   | 106 +++++++----
 .../processor/internals/RecordDeserializer.java    |  50 ++---
 .../streams/processor/internals/StreamTask.java    |  49 +++--
 .../streams/integration/EosIntegrationTest.java    |   5 +-
 .../integration/EosV2UpgradeIntegrationTest.java   |   6 +-
 .../ProcessingExceptionHandlerIntegrationTest.java |  30 +--
 .../processor/internals/RecordCollectorTest.java   | 194 +++++++++++++++----
 .../internals/RecordDeserializerTest.java          | 210 ++++++++++++++-------
 .../processor/internals/StreamTaskTest.java        | 140 ++++++++++----
 15 files changed, 589 insertions(+), 250 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/errors/DeserializationExceptionHandler.java
 
b/streams/src/main/java/org/apache/kafka/streams/errors/DeserializationExceptionHandler.java
index 0d64611de67..198a97cce44 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/errors/DeserializationExceptionHandler.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/errors/DeserializationExceptionHandler.java
@@ -37,7 +37,7 @@ public interface DeserializationExceptionHandler extends 
Configurable {
      * @param context processor context
      * @param record record that failed deserialization
      * @param exception the actual exception
-     * @deprecated Since 3.9. Use Please {@link #handle(ErrorHandlerContext, 
ConsumerRecord, Exception)}
+     * @deprecated Since 3.9. Use {@link #handle(ErrorHandlerContext, 
ConsumerRecord, Exception)} instead.
      */
     @Deprecated
     default DeserializationHandlerResponse handle(final ProcessorContext 
context,
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/errors/ProductionExceptionHandler.java
 
b/streams/src/main/java/org/apache/kafka/streams/errors/ProductionExceptionHandler.java
index 25aa00f7a92..939b1ecbcd6 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/errors/ProductionExceptionHandler.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/errors/ProductionExceptionHandler.java
@@ -59,7 +59,7 @@ public interface ProductionExceptionHandler extends 
Configurable {
      *
      * @param record        the record that failed to serialize
      * @param exception     the exception that occurred during serialization
-     * @deprecated Since 3.9. Use {@link #handle(ErrorHandlerContext, 
ProducerRecord, Exception)} instead.
+     * @deprecated Since 3.9. Use {@link 
#handleSerializationException(ErrorHandlerContext, ProducerRecord, Exception, 
SerializationExceptionOrigin)} instead.
      */
     @Deprecated
     default ProductionExceptionHandlerResponse 
handleSerializationException(final ProducerRecord record,
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/errors/internals/DefaultErrorHandlerContext.java
 
b/streams/src/main/java/org/apache/kafka/streams/errors/internals/DefaultErrorHandlerContext.java
index c907ff3eb89..77500ce3c36 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/errors/internals/DefaultErrorHandlerContext.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/errors/internals/DefaultErrorHandlerContext.java
@@ -81,6 +81,18 @@ public class DefaultErrorHandlerContext implements 
ErrorHandlerContext {
         return taskId;
     }
 
+    @Override
+    public String toString() {
+        // we exclude headers on purpose, to avoid accidentally logging user data
+        return "ErrorHandlerContext{" +
+            "topic='" + topic + '\'' +
+            ", partition=" + partition +
+            ", offset=" + offset +
+            ", processorNodeId='" + processorNodeId + '\'' +
+            ", taskId=" + taskId +
+            '}';
+    }
+
     public Optional<ProcessorContext> processorContext() {
         return Optional.ofNullable(processorContext);
     }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/errors/internals/FailedProcessingException.java
 
b/streams/src/main/java/org/apache/kafka/streams/errors/internals/FailedProcessingException.java
index 25f2ae9f6cc..03d687d2687 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/errors/internals/FailedProcessingException.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/errors/internals/FailedProcessingException.java
@@ -25,7 +25,12 @@ import org.apache.kafka.streams.errors.StreamsException;
 public class FailedProcessingException extends StreamsException {
     private static final long serialVersionUID = 1L;
 
+    public FailedProcessingException(final String errorMessage, final 
Exception exception) {
+        super(errorMessage, exception);
+    }
+
     public FailedProcessingException(final Exception exception) {
-        super(exception);
+        // pass `message` as `null` explicitly: `super(cause)` would set it to `cause.toString()`
+        super(null, exception);
     }
 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
index 6b7214a9ed1..12d4c6c603d 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
@@ -319,7 +319,7 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
                                 record.headers()));
                             restoreCount++;
                         }
-                    } catch (final Exception deserializationException) {
+                    } catch (final RuntimeException deserializationException) {
                         handleDeserializationFailure(
                             deserializationExceptionHandler,
                             globalProcessorContext,
@@ -330,7 +330,8 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
                                 Thread.currentThread().getName(),
                                 globalProcessorContext.taskId().toString(),
                                 globalProcessorContext.metrics()
-                            )
+                            ),
+                            null
                         );
                     }
                 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java
index 3ccfcf24905..07cce0730ee 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java
@@ -38,6 +38,7 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 
 import static 
org.apache.kafka.streams.StreamsConfig.PROCESSING_EXCEPTION_HANDLER_CLASS_CONFIG;
@@ -202,7 +203,7 @@ public class ProcessorNode<KIn, VIn, KOut, VOut> {
         } catch (final FailedProcessingException | TaskCorruptedException | 
TaskMigratedException e) {
             // Rethrow exceptions that should not be handled here
             throw e;
-        } catch (final RuntimeException e) {
+        } catch (final RuntimeException processingException) {
             final ErrorHandlerContext errorHandlerContext = new 
DefaultErrorHandlerContext(
                 null, // only required to pass for 
DeserializationExceptionHandler
                 internalProcessorContext.topic(),
@@ -213,18 +214,26 @@ public class ProcessorNode<KIn, VIn, KOut, VOut> {
                 internalProcessorContext.taskId());
 
             final ProcessingExceptionHandler.ProcessingHandlerResponse 
response;
-
             try {
-                response = 
processingExceptionHandler.handle(errorHandlerContext, record, e);
-            } catch (final Exception fatalUserException) {
-                throw new FailedProcessingException(fatalUserException);
+                response = Objects.requireNonNull(
+                    processingExceptionHandler.handle(errorHandlerContext, 
record, processingException),
+                    "Invalid ProcessingExceptionHandler response."
+                );
+            } catch (final RuntimeException fatalUserException) {
+                log.error(
+                    "Processing error callback failed after processing error 
for record: {}",
+                    errorHandlerContext,
+                    processingException
+                );
+                throw new FailedProcessingException("Fatal user code error in 
processing error callback", fatalUserException);
             }
+
             if (response == 
ProcessingExceptionHandler.ProcessingHandlerResponse.FAIL) {
                 log.error("Processing exception handler is set to fail upon" +
                      " a processing error. If you would rather have the 
streaming pipeline" +
                      " continue after a processing error, please set the " +
                      PROCESSING_EXCEPTION_HANDLER_CLASS_CONFIG + " 
appropriately.");
-                throw new FailedProcessingException(e);
+                throw new FailedProcessingException(processingException);
             } else {
                 droppedRecordsSensor.record();
             }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
index a4dc0a68062..e471587ed0d 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
@@ -45,6 +45,7 @@ import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.errors.internals.DefaultErrorHandlerContext;
+import org.apache.kafka.streams.errors.internals.FailedProcessingException;
 import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
@@ -59,6 +60,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicReference;
@@ -208,7 +210,7 @@ public class RecordCollectorImpl implements RecordCollector 
{
                 key,
                 keySerializer,
                 exception);
-        } catch (final Exception exception) {
+        } catch (final RuntimeException serializationException) {
             handleException(
                 ProductionExceptionHandler.SerializationExceptionOrigin.KEY,
                 topic,
@@ -219,7 +221,7 @@ public class RecordCollectorImpl implements RecordCollector 
{
                 timestamp,
                 processorNodeId,
                 context,
-                exception);
+                serializationException);
             return;
         }
 
@@ -232,7 +234,7 @@ public class RecordCollectorImpl implements RecordCollector 
{
                 value,
                 valueSerializer,
                 exception);
-        } catch (final Exception exception) {
+        } catch (final RuntimeException serializationException) {
             handleException(
                 ProductionExceptionHandler.SerializationExceptionOrigin.VALUE,
                 topic,
@@ -243,7 +245,7 @@ public class RecordCollectorImpl implements RecordCollector 
{
                 timestamp,
                 processorNodeId,
                 context,
-                exception);
+                serializationException);
             return;
         }
 
@@ -297,42 +299,51 @@ public class RecordCollectorImpl implements 
RecordCollector {
                                         final Long timestamp,
                                         final String processorNodeId,
                                         final InternalProcessorContext<Void, 
Void> context,
-                                        final Exception exception) {
+                                        final RuntimeException 
serializationException) {
+        log.debug(String.format("Error serializing record for topic %s", 
topic), serializationException);
+
+        final DefaultErrorHandlerContext errorHandlerContext = new 
DefaultErrorHandlerContext(
+            null, // only required to pass for DeserializationExceptionHandler
+            context.recordContext().topic(),
+            context.recordContext().partition(),
+            context.recordContext().offset(),
+            context.recordContext().headers(),
+            processorNodeId,
+            taskId
+        );
         final ProducerRecord<K, V> record = new ProducerRecord<>(topic, 
partition, timestamp, key, value, headers);
-        final ProductionExceptionHandlerResponse response;
-
-        log.debug(String.format("Error serializing record to topic %s", 
topic), exception);
 
+        final ProductionExceptionHandlerResponse response;
         try {
-            final DefaultErrorHandlerContext errorHandlerContext = new 
DefaultErrorHandlerContext(
-                null, // only required to pass for 
DeserializationExceptionHandler
-                context.recordContext().topic(),
-                context.recordContext().partition(),
-                context.recordContext().offset(),
-                context.recordContext().headers(),
-                processorNodeId,
-                taskId
+            response = Objects.requireNonNull(
+                
productionExceptionHandler.handleSerializationException(errorHandlerContext, 
record, serializationException, origin),
+                "Invalid ProductionExceptionHandler response."
             );
-            response = 
productionExceptionHandler.handleSerializationException(errorHandlerContext, 
record, exception, origin);
-        } catch (final Exception e) {
-            log.error("Fatal when handling serialization exception", e);
-            recordSendError(topic, e, null, context, processorNodeId);
-            return;
+        } catch (final RuntimeException fatalUserException) {
+            log.error(
+                String.format(
+                    "Production error callback failed after serialization 
error for record %s: %s",
+                    origin.toString().toLowerCase(Locale.ROOT),
+                    errorHandlerContext
+                ),
+                serializationException
+            );
+            throw new FailedProcessingException("Fatal user code error in 
production error callback", fatalUserException);
         }
 
         if (response == ProductionExceptionHandlerResponse.FAIL) {
             throw new StreamsException(
                 String.format(
                     "Unable to serialize record. ProducerRecord(topic=[%s], 
partition=[%d], timestamp=[%d]",
-                        topic,
-                        partition,
-                        timestamp),
-                    exception
+                    topic,
+                    partition,
+                    timestamp),
+                serializationException
             );
         }
 
         log.warn("Unable to serialize record, continue processing. " +
-                        "ProducerRecord(topic=[{}], partition=[{}], 
timestamp=[{}])",
+                    "ProducerRecord(topic=[{}], partition=[{}], 
timestamp=[{}])",
                 topic,
                 partition,
                 timestamp);
@@ -364,24 +375,24 @@ public class RecordCollectorImpl implements 
RecordCollector {
     }
 
     private void recordSendError(final String topic,
-                                 final Exception exception,
+                                 final Exception productionException,
                                  final ProducerRecord<byte[], byte[]> 
serializedRecord,
                                  final InternalProcessorContext<Void, Void> 
context,
                                  final String processorNodeId) {
-        String errorMessage = String.format(SEND_EXCEPTION_MESSAGE, topic, 
taskId, exception.toString());
+        String errorMessage = String.format(SEND_EXCEPTION_MESSAGE, topic, 
taskId, productionException.toString());
 
-        if (isFatalException(exception)) {
+        if (isFatalException(productionException)) {
             errorMessage += "\nWritten offsets would not be recorded and no 
more records would be sent since this is a fatal error.";
-            sendException.set(new StreamsException(errorMessage, exception));
-        } else if (exception instanceof ProducerFencedException ||
-                exception instanceof InvalidPidMappingException ||
-                exception instanceof InvalidProducerEpochException ||
-                exception instanceof OutOfOrderSequenceException) {
+            sendException.set(new StreamsException(errorMessage, 
productionException));
+        } else if (productionException instanceof ProducerFencedException ||
+                productionException instanceof InvalidPidMappingException ||
+                productionException instanceof InvalidProducerEpochException ||
+                productionException instanceof OutOfOrderSequenceException) {
             errorMessage += "\nWritten offsets would not be recorded and no 
more records would be sent since the producer is fenced, " +
                 "indicating the task may be migrated out";
-            sendException.set(new TaskMigratedException(errorMessage, 
exception));
+            sendException.set(new TaskMigratedException(errorMessage, 
productionException));
         } else {
-            if (isRetriable(exception)) {
+            if (isRetriable(productionException)) {
                 errorMessage += "\nThe broker is either slow or in bad state 
(like not having enough replicas) in responding the request, " +
                     "or the connection to broker was interrupted sending the 
request or receiving the response. " +
                     "\nConsider overwriting `max.block.ms` and /or " +
@@ -398,17 +409,34 @@ public class RecordCollectorImpl implements 
RecordCollector {
                     taskId
                 );
 
-                if (productionExceptionHandler.handle(errorHandlerContext, 
serializedRecord, exception) == ProductionExceptionHandlerResponse.FAIL) {
+                final ProductionExceptionHandlerResponse response;
+                try {
+                    response = Objects.requireNonNull(
+                        productionExceptionHandler.handle(errorHandlerContext, 
serializedRecord, productionException),
+                        "Invalid ProductionExceptionHandler response."
+                    );
+                } catch (final RuntimeException fatalUserException) {
+                    log.error(
+                        "Production error callback failed after production 
error for record {}",
+                        serializedRecord,
+                        productionException
+                    );
+                    sendException.set(new FailedProcessingException("Fatal 
user code error in production error callback", fatalUserException));
+                    return;
+                }
+
+                if (response == ProductionExceptionHandlerResponse.FAIL) {
                     errorMessage += "\nException handler choose to FAIL the 
processing, no more records would be sent.";
-                    sendException.set(new StreamsException(errorMessage, 
exception));
+                    sendException.set(new StreamsException(errorMessage, 
productionException));
                 } else {
                     errorMessage += "\nException handler choose to CONTINUE 
processing in spite of this error but written offsets would not be recorded.";
                     droppedRecordsSensor.record();
                 }
+
             }
         }
 
-        log.error(errorMessage, exception);
+        log.error(errorMessage, productionException);
     }
 
     /**
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java
index 8ee2dc014eb..5fc03352ecc 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java
@@ -21,12 +21,14 @@ import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.record.TimestampType;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.errors.DeserializationExceptionHandler;
+import 
org.apache.kafka.streams.errors.DeserializationExceptionHandler.DeserializationHandlerResponse;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.internals.DefaultErrorHandlerContext;
 import org.apache.kafka.streams.processor.api.ProcessorContext;
 
 import org.slf4j.Logger;
 
+import java.util.Objects;
 import java.util.Optional;
 
 import static 
org.apache.kafka.streams.StreamsConfig.DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_CONFIG;
@@ -49,7 +51,7 @@ public class RecordDeserializer {
 
     /**
      * @throws StreamsException if a deserialization error occurs and the 
deserialization callback returns
-     *                          {@link 
DeserializationExceptionHandler.DeserializationHandlerResponse#FAIL FAIL}
+     *                          {@link DeserializationHandlerResponse#FAIL 
FAIL}
      *                          or throws an exception itself
      */
     ConsumerRecord<Object, Object> deserialize(final ProcessorContext<?, ?> 
processorContext,
@@ -69,7 +71,7 @@ public class RecordDeserializer {
                 rawRecord.headers(),
                 Optional.empty()
             );
-        } catch (final Exception deserializationException) {
+        } catch (final RuntimeException deserializationException) {
             handleDeserializationFailure(deserializationExceptionHandler, 
processorContext, deserializationException, rawRecord, log, 
droppedRecordsSensor, sourceNode().name());
             return null; //  'handleDeserializationFailure' would either throw 
or swallow -- if we swallow we need to skip the record by returning 'null'
         }
@@ -77,39 +79,37 @@ public class RecordDeserializer {
 
     public static void handleDeserializationFailure(final 
DeserializationExceptionHandler deserializationExceptionHandler,
                                                     final ProcessorContext<?, 
?> processorContext,
-                                                    final Exception 
deserializationException,
-                                                    final 
ConsumerRecord<byte[], byte[]> rawRecord,
-                                                    final Logger log,
-                                                    final Sensor 
droppedRecordsSensor) {
-        handleDeserializationFailure(deserializationExceptionHandler, 
processorContext, deserializationException, rawRecord, log, 
droppedRecordsSensor, null);
-    }
-
-    public static void handleDeserializationFailure(final 
DeserializationExceptionHandler deserializationExceptionHandler,
-                                                    final ProcessorContext<?, 
?> processorContext,
-                                                    final Exception 
deserializationException,
+                                                    final RuntimeException 
deserializationException,
                                                     final 
ConsumerRecord<byte[], byte[]> rawRecord,
                                                     final Logger log,
                                                     final Sensor 
droppedRecordsSensor,
                                                     final String 
sourceNodeName) {
-        final DeserializationExceptionHandler.DeserializationHandlerResponse 
response;
+
+        final DefaultErrorHandlerContext errorHandlerContext = new 
DefaultErrorHandlerContext(
+            (InternalProcessorContext<?, ?>) processorContext,
+            rawRecord.topic(),
+            rawRecord.partition(),
+            rawRecord.offset(),
+            rawRecord.headers(),
+            sourceNodeName,
+            processorContext.taskId());
+
+        final DeserializationHandlerResponse response;
         try {
-            final DefaultErrorHandlerContext errorHandlerContext = new 
DefaultErrorHandlerContext(
-                (InternalProcessorContext<?, ?>) processorContext,
-                rawRecord.topic(),
-                rawRecord.partition(),
-                rawRecord.offset(),
-                rawRecord.headers(),
-                sourceNodeName,
-                processorContext.taskId());
-            response = 
deserializationExceptionHandler.handle(errorHandlerContext, rawRecord, 
deserializationException);
-        } catch (final Exception fatalUserException) {
+            response = Objects.requireNonNull(
+                deserializationExceptionHandler.handle(errorHandlerContext, 
rawRecord, deserializationException),
+                "Invalid DeserializationExceptionHandler response."
+            );
+        } catch (final RuntimeException fatalUserException) {
             log.error(
                 "Deserialization error callback failed after deserialization 
error for record {}",
                 rawRecord,
-                deserializationException);
+                deserializationException
+            );
             throw new StreamsException("Fatal user code error in 
deserialization error callback", fatalUserException);
         }
-        if (response == 
DeserializationExceptionHandler.DeserializationHandlerResponse.FAIL) {
+
+        if (response == DeserializationHandlerResponse.FAIL) {
             throw new StreamsException("Deserialization exception handler is 
set to fail upon" +
                 " a deserialization error. If you would rather have the 
streaming pipeline" +
                 " continue after a deserialization error, please set the " +
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index 6f2edd442b0..f08cfa7fd67 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -59,6 +59,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
 import java.util.function.Function;
@@ -807,7 +808,7 @@ public class StreamTask extends AbstractTask implements 
ProcessorNodePunctuator,
             }
         } catch (final FailedProcessingException failedProcessingException) {
             // Do not keep the failed processing exception in the stack trace
-            handleException(failedProcessingException.getCause());
+            handleException(failedProcessingException.getMessage(), 
failedProcessingException.getCause());
         } catch (final StreamsException exception) {
             record = null;
             throw exception;
@@ -820,19 +821,25 @@ public class StreamTask extends AbstractTask implements 
ProcessorNodePunctuator,
         return true;
     }
 
-    private void handleException(final Throwable e) {
-        final StreamsException error = new StreamsException(
+    private void handleException(final Throwable originalException) {
+        handleException(
             String.format(
-                "Exception caught in process. taskId=%s, processor=%s, 
topic=%s, partition=%d, offset=%d, stacktrace=%s",
+                "Exception caught in process. taskId=%s, processor=%s, 
topic=%s, partition=%d, offset=%d",
                 id(),
                 processorContext.currentNode().name(),
                 record.topic(),
                 record.partition(),
-                record.offset(),
-                getStacktraceString(e)
+                record.offset()
             ),
-            e
-        );
+            originalException);
+    }
+
+    private void handleException(final String errorMessage, final Throwable 
originalException) {
+        if (errorMessage == null) {
+            handleException(originalException); // always throws
+        }
+
+        final StreamsException error = new StreamsException(errorMessage, 
originalException);
         record = null;
 
         throw error;
@@ -920,11 +927,18 @@ public class StreamTask extends AbstractTask implements 
ProcessorNodePunctuator,
 
         try {
             maybeMeasureLatency(() -> punctuator.punctuate(timestamp), time, 
punctuateLatencySensor);
+        } catch (final TimeoutException timeoutException) {
+            if (!eosEnabled) {
+                throw timeoutException;
+            } else {
+                record = null;
+                throw new TaskCorruptedException(Collections.singleton(id));
+            }
         } catch (final FailedProcessingException e) {
             throw createStreamsException(node.name(), e.getCause());
         } catch (final TaskCorruptedException | TaskMigratedException e) {
             throw e;
-        } catch (final Exception e) {
+        } catch (final RuntimeException processingException) {
             final ErrorHandlerContext errorHandlerContext = new 
DefaultErrorHandlerContext(
                 null,
                 recordContext.topic(),
@@ -936,11 +950,18 @@ public class StreamTask extends AbstractTask implements 
ProcessorNodePunctuator,
             );
 
             final ProcessingExceptionHandler.ProcessingHandlerResponse 
response;
-
             try {
-                response = 
processingExceptionHandler.handle(errorHandlerContext, null, e);
-            } catch (final Exception fatalUserException) {
-                throw new FailedProcessingException(fatalUserException);
+                response = Objects.requireNonNull(
+                    processingExceptionHandler.handle(errorHandlerContext, 
null, processingException),
+                    "Invalid ProcessingExceptionHandler response."
+                );
+            } catch (final RuntimeException fatalUserException) {
+                log.error(
+                    "Processing error callback failed after processing error 
for record: {}",
+                    errorHandlerContext,
+                    processingException
+                );
+                throw new FailedProcessingException("Fatal user code error in 
processing error callback", fatalUserException);
             }
 
             if (response == 
ProcessingExceptionHandler.ProcessingHandlerResponse.FAIL) {
@@ -949,7 +970,7 @@ public class StreamTask extends AbstractTask implements 
ProcessorNodePunctuator,
                         " continue after a processing error, please set the " +
                         PROCESSING_EXCEPTION_HANDLER_CLASS_CONFIG + " 
appropriately.");
 
-                throw createStreamsException(node.name(), e);
+                throw createStreamsException(node.name(), processingException);
             } else {
                 droppedRecordsSensor.record();
             }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
index 94e48ee3d49..3b588439e92 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
@@ -39,6 +39,7 @@ import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsConfig.InternalConfig;
 import org.apache.kafka.streams.Topology;
+import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
@@ -1170,7 +1171,9 @@ public class EosIntegrationTest {
         final KafkaStreams streams = new KafkaStreams(builder.build(), config);
 
         streams.setUncaughtExceptionHandler((t, e) -> {
-            if (uncaughtException != null || 
!e.getMessage().contains("Injected test exception")) {
+            if (uncaughtException != null ||
+                !(e instanceof StreamsException) ||
+                !e.getCause().getMessage().equals("Injected test exception.")) 
{
                 e.printStackTrace(System.err);
                 hasUnexpectedError = true;
             }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/integration/EosV2UpgradeIntegrationTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/integration/EosV2UpgradeIntegrationTest.java
index 8653f69fca6..4dd8e5697c4 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/integration/EosV2UpgradeIntegrationTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/integration/EosV2UpgradeIntegrationTest.java
@@ -35,6 +35,7 @@ import org.apache.kafka.streams.StoreQueryParameters;
 import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsConfig.InternalConfig;
+import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler;
 import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
@@ -947,8 +948,9 @@ public class EosV2UpgradeIntegrationTest {
             } else {
                 int exceptionCount = exceptionCounts.get(appDir);
                 // should only have our injected exception or commit 
exception, and 2 exceptions for each stream
-                if (++exceptionCount > 2 || !(e instanceof RuntimeException) ||
-                    !(e.getMessage().contains("test exception"))) {
+                if (++exceptionCount > 2 ||
+                    !(e instanceof StreamsException) ||
+                    !(e.getCause().getMessage().endsWith(" test exception."))) 
{
                     // The exception won't cause the test fail since we 
actually "expected" exception thrown and failed the stream.
                     // So, log to stderr for debugging when the exception is 
not what we expected, and fail in the main thread
                     e.printStackTrace(System.err);
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/integration/ProcessingExceptionHandlerIntegrationTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/integration/ProcessingExceptionHandlerIntegrationTest.java
index 61b5ed16bb1..d0c32310550 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/integration/ProcessingExceptionHandlerIntegrationTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/integration/ProcessingExceptionHandlerIntegrationTest.java
@@ -51,6 +51,7 @@ import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
 import static org.junit.jupiter.api.Assertions.assertIterableEquals;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -61,7 +62,7 @@ public class ProcessingExceptionHandlerIntegrationTest {
     private final String threadId = Thread.currentThread().getName();
 
     @Test
-    public void shouldFailWhenProcessingExceptionOccurs() {
+    public void 
shouldFailWhenProcessingExceptionOccursIfExceptionHandlerReturnsFail() {
         final List<KeyValue<String, String>> events = Arrays.asList(
             new KeyValue<>("ID123-1", "ID123-A1"),
             new KeyValue<>("ID123-2-ERR", "ID123-A2"),
@@ -93,8 +94,7 @@ public class ProcessingExceptionHandlerIntegrationTest {
 
             assertTrue(exception.getMessage().contains("Exception caught in 
process. "
                 + "taskId=0_0, processor=KSTREAM-SOURCE-0000000000, 
topic=TOPIC_NAME, "
-                + "partition=0, offset=1, 
stacktrace=java.lang.RuntimeException: "
-                + "Exception should be handled by processing exception 
handler"));
+                + "partition=0, offset=1"));
             assertEquals(1, 
processor.theCapturedProcessor().processed().size());
             assertIterableEquals(expectedProcessedRecords, 
processor.theCapturedProcessor().processed());
 
@@ -107,7 +107,7 @@ public class ProcessingExceptionHandlerIntegrationTest {
     }
 
     @Test
-    public void shouldContinueWhenProcessingExceptionOccurs() {
+    public void 
shouldContinueWhenProcessingExceptionOccursIfExceptionHandlerReturnsContinue() {
         final List<KeyValue<String, String>> events = Arrays.asList(
             new KeyValue<>("ID123-1", "ID123-A1"),
             new KeyValue<>("ID123-2-ERR", "ID123-A2"),
@@ -182,8 +182,7 @@ public class ProcessingExceptionHandlerIntegrationTest {
             final StreamsException e = assertThrows(StreamsException.class, () 
-> inputTopic.pipeInput(eventError.key, eventError.value, Instant.EPOCH));
             assertTrue(e.getMessage().contains("Exception caught in process. "
                 + "taskId=0_0, processor=KSTREAM-SOURCE-0000000000, 
topic=TOPIC_NAME, "
-                + "partition=0, offset=1, 
stacktrace=java.lang.RuntimeException: "
-                + "Exception should be handled by processing exception 
handler"));
+                + "partition=0, offset=1"));
             assertFalse(isExecuted.get());
         }
     }
@@ -222,9 +221,9 @@ public class ProcessingExceptionHandlerIntegrationTest {
     }
 
     @Test
-    public void 
shouldStopProcessingWhenFatalUserExceptionInFailProcessingExceptionHandler() {
+    public void 
shouldStopProcessingWhenProcessingExceptionHandlerReturnsNull() {
         final KeyValue<String, String> event = new KeyValue<>("ID123-1", 
"ID123-A1");
-        final KeyValue<String, String> eventError = new 
KeyValue<>("ID123-ERR-FATAL", "ID123-A2");
+        final KeyValue<String, String> eventError = new 
KeyValue<>("ID123-ERR-NULL", "ID123-A2");
 
         final MockProcessorSupplier<String, String, Void, Void> processor = 
new MockProcessorSupplier<>();
         final StreamsBuilder builder = new StreamsBuilder();
@@ -241,7 +240,7 @@ public class ProcessingExceptionHandlerIntegrationTest {
             .process(processor);
 
         final Properties properties = new Properties();
-        
properties.put(StreamsConfig.PROCESSING_EXCEPTION_HANDLER_CLASS_CONFIG, 
FailProcessingExceptionHandlerMockTest.class);
+        
properties.put(StreamsConfig.PROCESSING_EXCEPTION_HANDLER_CLASS_CONFIG, 
ContinueProcessingExceptionHandlerMockTest.class);
 
         try (final TopologyTestDriver driver = new 
TopologyTestDriver(builder.build(), properties, Instant.ofEpochMilli(0L))) {
             final TestInputTopic<String, String> inputTopic = 
driver.createInputTopic("TOPIC_NAME", new StringSerializer(), new 
StringSerializer());
@@ -250,13 +249,15 @@ public class ProcessingExceptionHandlerIntegrationTest {
             assertTrue(isExecuted.get());
             isExecuted.set(false);
             final StreamsException e = assertThrows(StreamsException.class, () 
-> inputTopic.pipeInput(eventError.key, eventError.value, Instant.EPOCH));
-            assertEquals("KABOOM!", e.getCause().getMessage());
+            assertEquals("Fatal user code error in processing error callback", 
e.getMessage());
+            assertInstanceOf(NullPointerException.class, e.getCause());
+            assertEquals("Invalid ProcessingExceptionHandler response.", 
e.getCause().getMessage());
             assertFalse(isExecuted.get());
         }
     }
 
     @Test
-    public void 
shouldStopProcessingWhenFatalUserExceptionInContinueProcessingExceptionHandler()
 {
+    public void 
shouldStopProcessingWhenFatalUserExceptionProcessingExceptionHandler() {
         final KeyValue<String, String> event = new KeyValue<>("ID123-1", 
"ID123-A1");
         final KeyValue<String, String> eventError = new 
KeyValue<>("ID123-ERR-FATAL", "ID123-A2");
 
@@ -284,6 +285,7 @@ public class ProcessingExceptionHandlerIntegrationTest {
             assertTrue(isExecuted.get());
             isExecuted.set(false);
             final StreamsException e = assertThrows(StreamsException.class, () 
-> inputTopic.pipeInput(eventError.key, eventError.value, Instant.EPOCH));
+            assertEquals("Fatal user code error in processing error callback", 
e.getMessage());
             assertEquals("KABOOM!", e.getCause().getMessage());
             assertFalse(isExecuted.get());
         }
@@ -295,6 +297,9 @@ public class ProcessingExceptionHandlerIntegrationTest {
             if (((String) record.key()).contains("FATAL")) {
                 throw new RuntimeException("KABOOM!");
             }
+            if (((String) record.key()).contains("NULL")) {
+                return null;
+            }
             assertProcessingExceptionHandlerInputs(context, record, exception);
             return 
ProcessingExceptionHandler.ProcessingHandlerResponse.CONTINUE;
         }
@@ -308,9 +313,6 @@ public class ProcessingExceptionHandlerIntegrationTest {
     public static class FailProcessingExceptionHandlerMockTest implements 
ProcessingExceptionHandler {
         @Override
         public ProcessingExceptionHandler.ProcessingHandlerResponse 
handle(final ErrorHandlerContext context, final Record<?, ?> record, final 
Exception exception) {
-            if (((String) record.key()).contains("FATAL")) {
-                throw new RuntimeException("KABOOM!");
-            }
             assertProcessingExceptionHandlerInputs(context, record, exception);
             return ProcessingExceptionHandler.ProcessingHandlerResponse.FAIL;
         }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
index 26231028851..ec8f6e5a9f9 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
@@ -51,6 +51,8 @@ import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.DefaultProductionExceptionHandler;
 import org.apache.kafka.streams.errors.ErrorHandlerContext;
 import org.apache.kafka.streams.errors.ProductionExceptionHandler;
+import 
org.apache.kafka.streams.errors.ProductionExceptionHandler.ProductionExceptionHandlerResponse;
+import 
org.apache.kafka.streams.errors.ProductionExceptionHandler.SerializationExceptionOrigin;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.processor.StreamPartitioner;
@@ -1242,7 +1244,7 @@ public class RecordCollectorTest {
             logContext,
             taskId,
             getExceptionalStreamsProducerOnSend(exception),
-            new 
ProductionExceptionHandlerMock(ProductionExceptionHandler.ProductionExceptionHandlerResponse.CONTINUE),
+            new 
ProductionExceptionHandlerMock(Optional.of(ProductionExceptionHandlerResponse.CONTINUE)),
             streamsMetrics,
             topology
         );
@@ -1269,7 +1271,7 @@ public class RecordCollectorTest {
             logContext,
             taskId,
             getExceptionalStreamsProducerOnSend(exception),
-            new 
ProductionExceptionHandlerMock(ProductionExceptionHandler.ProductionExceptionHandlerResponse.CONTINUE),
+            new 
ProductionExceptionHandlerMock(Optional.of(ProductionExceptionHandlerResponse.CONTINUE)),
             streamsMetrics,
             topology
         );
@@ -1293,7 +1295,7 @@ public class RecordCollectorTest {
             logContext,
             taskId,
             getExceptionalStreamsProducerOnSend(exception),
-            new 
ProductionExceptionHandlerMock(ProductionExceptionHandler.ProductionExceptionHandlerResponse.CONTINUE),
+            new 
ProductionExceptionHandlerMock(Optional.of(ProductionExceptionHandlerResponse.CONTINUE)),
             streamsMetrics,
             topology
         );
@@ -1317,7 +1319,7 @@ public class RecordCollectorTest {
             taskId,
             getExceptionalStreamsProducerOnSend(new 
RuntimeException("KABOOM!")),
             new ProductionExceptionHandlerMock(
-                
ProductionExceptionHandler.ProductionExceptionHandlerResponse.CONTINUE,
+                Optional.of(ProductionExceptionHandlerResponse.CONTINUE),
                 context,
                 sinkNodeName,
                 taskId
@@ -1363,12 +1365,12 @@ public class RecordCollectorTest {
     public void 
shouldThrowStreamsExceptionOnUnknownTopicOrPartitionExceptionWithDefaultExceptionHandler()
 {
         final KafkaException exception = new TimeoutException("KABOOM!", new 
UnknownTopicOrPartitionException());
         final RecordCollector collector = new RecordCollectorImpl(
-                logContext,
-                taskId,
-                getExceptionalStreamsProducerOnSend(exception),
-                productionExceptionHandler,
-                streamsMetrics,
-                topology
+            logContext,
+            taskId,
+            getExceptionalStreamsProducerOnSend(exception),
+            productionExceptionHandler,
+            streamsMetrics,
+            topology
         );
 
         collector.send(topic, "3", "0", null, null, stringSerializer, 
stringSerializer, sinkNodeName, context, streamPartitioner);
@@ -1378,10 +1380,10 @@ public class RecordCollectorTest {
         final StreamsException thrown = assertThrows(StreamsException.class, 
collector::flush);
         assertEquals(exception, thrown.getCause());
         assertThat(
-                thrown.getMessage(),
-                equalTo("Error encountered sending record to topic topic for 
task 0_0 due to:" +
-                        "\norg.apache.kafka.common.errors.TimeoutException: 
KABOOM!" +
-                        "\nException handler choose to FAIL the processing, no 
more records would be sent.")
+            thrown.getMessage(),
+            equalTo("Error encountered sending record to topic topic for task 
0_0 due to:" +
+                    "\norg.apache.kafka.common.errors.TimeoutException: 
KABOOM!" +
+                    "\nException handler choose to FAIL the processing, no 
more records would be sent.")
         );
     }
 
@@ -1389,17 +1391,17 @@ public class RecordCollectorTest {
     public void 
shouldNotThrowTaskCorruptedExceptionOnUnknownTopicOrPartitionExceptionUsingAlwaysContinueExceptionHandler()
 {
         final KafkaException exception = new TimeoutException("KABOOM!", new 
UnknownTopicOrPartitionException());
         final RecordCollector collector = new RecordCollectorImpl(
-                logContext,
-                taskId,
-                getExceptionalStreamsProducerOnSend(exception),
-                new ProductionExceptionHandlerMock(
-                    
ProductionExceptionHandler.ProductionExceptionHandlerResponse.CONTINUE,
-                    context,
-                    sinkNodeName,
-                    taskId
-                ),
-                streamsMetrics,
-                topology
+            logContext,
+            taskId,
+            getExceptionalStreamsProducerOnSend(exception),
+            new ProductionExceptionHandlerMock(
+                Optional.of(ProductionExceptionHandlerResponse.CONTINUE),
+                context,
+                sinkNodeName,
+                taskId
+            ),
+            streamsMetrics,
+            topology
         );
 
         collector.send(topic, "3", "0", null, null, stringSerializer, 
stringSerializer, sinkNodeName, context, streamPartitioner);
@@ -1539,11 +1541,11 @@ public class RecordCollectorTest {
     public void shouldDropRecordExceptionUsingAlwaysContinueExceptionHandler() 
{
         try (final ErrorStringSerializer errorSerializer = new 
ErrorStringSerializer()) {
             final RecordCollector collector = newRecordCollector(new 
ProductionExceptionHandlerMock(
-                
ProductionExceptionHandler.ProductionExceptionHandlerResponse.CONTINUE,
+                Optional.of(ProductionExceptionHandlerResponse.CONTINUE),
                 context,
                 sinkNodeName,
                 taskId,
-                ProductionExceptionHandler.SerializationExceptionOrigin.KEY
+                SerializationExceptionOrigin.KEY
             ));
             collector.initialize();
 
@@ -1568,11 +1570,11 @@ public class RecordCollectorTest {
     public void 
shouldThrowStreamsExceptionWhenValueSerializationFailedAndProductionExceptionHandlerRepliesWithFail()
 {
         try (final ErrorStringSerializer errorSerializer = new 
ErrorStringSerializer()) {
             final RecordCollector collector = newRecordCollector(new 
ProductionExceptionHandlerMock(
-                
ProductionExceptionHandler.ProductionExceptionHandlerResponse.FAIL,
+                Optional.of(ProductionExceptionHandlerResponse.FAIL),
                 context,
                 sinkNodeName,
                 taskId,
-                ProductionExceptionHandler.SerializationExceptionOrigin.VALUE
+                SerializationExceptionOrigin.VALUE
             ));
             collector.initialize();
 
@@ -1589,11 +1591,11 @@ public class RecordCollectorTest {
     public void 
shouldThrowStreamsExceptionWhenKeySerializationFailedAndProductionExceptionHandlerRepliesWithFail()
 {
         try (final ErrorStringSerializer errorSerializer = new 
ErrorStringSerializer()) {
             final RecordCollector collector = newRecordCollector(new 
ProductionExceptionHandlerMock(
-                
ProductionExceptionHandler.ProductionExceptionHandlerResponse.FAIL,
+                Optional.of(ProductionExceptionHandlerResponse.FAIL),
                 context,
                 sinkNodeName,
                 taskId,
-                ProductionExceptionHandler.SerializationExceptionOrigin.KEY
+                SerializationExceptionOrigin.KEY
             ));
             collector.initialize();
 
@@ -1606,11 +1608,109 @@ public class RecordCollectorTest {
         }
     }
 
+    @Test
+    public void 
shouldThrowStreamsExceptionWhenSerializationFailedAndProductionExceptionHandlerReturnsNull()
 {
+        try (final ErrorStringSerializer errorSerializer = new 
ErrorStringSerializer()) {
+            final RecordCollector collector = newRecordCollector(new 
ProductionExceptionHandlerMock(
+                Optional.empty(),
+                context,
+                sinkNodeName,
+                taskId,
+                SerializationExceptionOrigin.KEY
+            ));
+            collector.initialize();
+
+            final StreamsException exception = assertThrows(
+                StreamsException.class,
+                () -> collector.send(topic, "key", "val", null, 0, null, 
errorSerializer, stringSerializer, sinkNodeName, context)
+            );
+
+            assertEquals("Fatal user code error in production error callback", 
exception.getMessage());
+            assertInstanceOf(NullPointerException.class, exception.getCause());
+            assertEquals("Invalid ProductionExceptionHandler response.", 
exception.getCause().getMessage());
+        }
+    }
+
+    @Test
+    public void 
shouldThrowStreamsExceptionWhenSerializationFailedAndProductionExceptionHandlerThrows()
 {
+        try (final ErrorStringSerializer errorSerializer = new 
ErrorStringSerializer()) {
+            final RecordCollector collector = newRecordCollector(new 
ProductionExceptionHandlerMock(
+                true,
+                context,
+                sinkNodeName,
+                taskId,
+                SerializationExceptionOrigin.KEY
+            ));
+            collector.initialize();
+
+            final StreamsException exception = assertThrows(
+                StreamsException.class,
+                () -> collector.send(topic, "key", "val", null, 0, null, 
errorSerializer, stringSerializer, sinkNodeName, context)
+            );
+
+            assertEquals("Fatal user code error in production error callback", 
exception.getMessage());
+            assertEquals("CRASH", exception.getCause().getMessage());
+        }
+    }
+
+    @Test
+    public void 
shouldThrowStreamsExceptionOnSubsequentFlushIfASendFailsAndProductionExceptionHandlerReturnsNull()
 {
+        final KafkaException exception = new KafkaException("KABOOM!");
+        final RecordCollector collector = new RecordCollectorImpl(
+            logContext,
+            taskId,
+            getExceptionalStreamsProducerOnSend(exception),
+            new ProductionExceptionHandlerMock(
+                Optional.empty(),
+                context,
+                sinkNodeName,
+                taskId,
+                SerializationExceptionOrigin.KEY
+            ),
+            streamsMetrics,
+            topology
+        );
+
+        collector.send(topic, "3", "0", null, null, stringSerializer, 
stringSerializer, sinkNodeName, context, streamPartitioner);
+
+        final StreamsException thrown = assertThrows(StreamsException.class, 
collector::flush);
+        assertEquals("Fatal user code error in production error callback", 
thrown.getMessage());
+        assertInstanceOf(NullPointerException.class, thrown.getCause());
+        assertEquals("Invalid ProductionExceptionHandler response.", 
thrown.getCause().getMessage());
+    }
+
+    @Test
+    public void shouldThrowStreamsExceptionOnSubsequentFlushIfASendFailsAndProductionExceptionHandlerThrows() {
+        final KafkaException exception = new KafkaException("KABOOM!");
+        final RecordCollector collector = new RecordCollectorImpl(
+            logContext,
+            taskId,
+            getExceptionalStreamsProducerOnSend(exception),
+            new ProductionExceptionHandlerMock(
+                true,
+                context,
+                sinkNodeName,
+                taskId,
+                SerializationExceptionOrigin.KEY
+            ),
+            streamsMetrics,
+            topology
+        );
+
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, sinkNodeName, context, streamPartitioner);
+
+        final StreamsException thrown = assertThrows(StreamsException.class, collector::flush);
+        assertEquals("Fatal user code error in production error callback", thrown.getMessage());
+        assertEquals("CRASH", thrown.getCause().getMessage());
+    }
+
     @SuppressWarnings({"unchecked", "rawtypes"})
     @Test
     public void shouldNotCallProductionExceptionHandlerOnClassCastException() {
         try (final ErrorStringSerializer errorSerializer = new ErrorStringSerializer()) {
-            final RecordCollector collector = newRecordCollector(new ProductionExceptionHandlerMock(ProductionExceptionHandler.ProductionExceptionHandlerResponse.CONTINUE));
+            final RecordCollector collector = newRecordCollector(
+                new ProductionExceptionHandlerMock(Optional.of(ProductionExceptionHandlerResponse.CONTINUE))
+            );
             collector.initialize();
 
             assertThat(mockProducer.history().isEmpty(), equalTo(true));
@@ -1766,17 +1866,18 @@ public class RecordCollectorTest {
     }
 
     public static class ProductionExceptionHandlerMock implements ProductionExceptionHandler {
-        private final ProductionExceptionHandlerResponse response;
+        private final Optional<ProductionExceptionHandlerResponse> response;
+        private boolean shouldThrowException;
         private InternalProcessorContext<Void, Void> expectedContext;
         private String expectedProcessorNodeId;
         private TaskId expectedTaskId;
         private SerializationExceptionOrigin expectedSerializationExceptionOrigin;
 
-        public ProductionExceptionHandlerMock(final ProductionExceptionHandlerResponse response) {
+        public ProductionExceptionHandlerMock(final Optional<ProductionExceptionHandlerResponse> response) {
             this.response = response;
         }
 
-        public ProductionExceptionHandlerMock(final ProductionExceptionHandlerResponse response,
+        public ProductionExceptionHandlerMock(final Optional<ProductionExceptionHandlerResponse> response,
                                               final InternalProcessorContext<Void, Void> context,
                                               final String processorNodeId,
                                               final TaskId taskId) {
@@ -1786,13 +1887,24 @@ public class RecordCollectorTest {
             this.expectedTaskId = taskId;
         }
 
-        public ProductionExceptionHandlerMock(final ProductionExceptionHandlerResponse response,
+        public ProductionExceptionHandlerMock(final boolean shouldThrowException,
+                                              final InternalProcessorContext<Void, Void> context,
+                                              final String processorNodeId,
+                                              final TaskId taskId,
+                                              final SerializationExceptionOrigin origin) {
+            this(Optional.empty(), context, processorNodeId, taskId);
+            this.expectedSerializationExceptionOrigin = origin;
+            this.shouldThrowException = shouldThrowException;
+        }
+
+        public ProductionExceptionHandlerMock(final Optional<ProductionExceptionHandlerResponse> response,
                                               final InternalProcessorContext<Void, Void> context,
                                               final String processorNodeId,
                                               final TaskId taskId,
                                               final SerializationExceptionOrigin origin) {
             this(response, context, processorNodeId, taskId);
             this.expectedSerializationExceptionOrigin = origin;
+            this.shouldThrowException = false;
         }
 
         @Override
@@ -1800,7 +1912,10 @@ public class RecordCollectorTest {
                                                          final ProducerRecord<byte[], byte[]> record,
                                                          final Exception exception) {
             assertInputs(context, exception);
-            return response;
+            if (shouldThrowException) {
+                throw new RuntimeException("CRASH");
+            }
+            return response.orElse(null);
         }
 
         @Override
@@ -1810,7 +1925,10 @@
                                                                                 final SerializationExceptionOrigin origin) {
             assertInputs(context, exception);
             assertEquals(expectedSerializationExceptionOrigin, origin);
-            return response;
+            if (shouldThrowException) {
+                throw new RuntimeException("CRASH");
+            }
+            return response.orElse(null);
         }
 
         @Override
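
For context: the RecordCollectorTest changes above pin down two handler misbehaviors that are now treated uniformly, a ProductionExceptionHandler that throws from its callback and one that returns null. A minimal sketch of such a misbehaving handler against the KIP-1033 callback (hypothetical class and config key, not part of this commit):

    import java.util.Map;

    import org.apache.kafka.clients.producer.ProducerRecord;
    import org.apache.kafka.streams.errors.ErrorHandlerContext;
    import org.apache.kafka.streams.errors.ProductionExceptionHandler;

    // Hypothetical handler exercising both failure modes tested above. Either way,
    // Streams wraps the problem into a StreamsException with the message
    // "Fatal user code error in production error callback"; the null return
    // additionally carries a NullPointerException cause reading
    // "Invalid ProductionExceptionHandler response.".
    public class MisbehavingProductionExceptionHandler implements ProductionExceptionHandler {
        private boolean crash;

        @Override
        public ProductionExceptionHandlerResponse handle(final ErrorHandlerContext context,
                                                         final ProducerRecord<byte[], byte[]> record,
                                                         final Exception exception) {
            if (crash) {
                throw new RuntimeException("CRASH"); // the handler itself fails
            }
            return null; // invalid response
        }

        @Override
        public void configure(final Map<String, ?> configs) {
            // "crash" is a hypothetical config key used only by this sketch
            crash = Boolean.parseBoolean(String.valueOf(configs.get("crash")));
        }
    }

A correct handler would instead return CONTINUE or FAIL unconditionally; the tests above assert that anything else fails the task rather than being silently swallowed.
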
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordDeserializerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordDeserializerTest.java
index 23f364fc6a3..1bca1c9e379 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordDeserializerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordDeserializerTest.java
@@ -25,6 +25,7 @@ import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.record.TimestampType;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.errors.DeserializationExceptionHandler;
+import org.apache.kafka.streams.errors.DeserializationExceptionHandler.DeserializationHandlerResponse;
 import org.apache.kafka.streams.errors.ErrorHandlerContext;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.TaskId;
@@ -62,27 +63,29 @@ public class RecordDeserializerTest {
 
     @Test
     public void shouldReturnConsumerRecordWithDeserializedValueWhenNoExceptions() {
-        final RecordDeserializer recordDeserializer = new RecordDeserializer(
-            new TheSourceNode(
-                sourceNodeName,
-                false,
-                false,
-                "key",
-                "value"
-            ),
-            null,
-            new LogContext(),
-            new Metrics().sensor("dropped-records")
-        );
-        final ConsumerRecord<Object, Object> record = recordDeserializer.deserialize(null, rawRecord);
-        assertEquals(rawRecord.topic(), record.topic());
-        assertEquals(rawRecord.partition(), record.partition());
-        assertEquals(rawRecord.offset(), record.offset());
-        assertEquals("key", record.key());
-        assertEquals("value", record.value());
-        assertEquals(rawRecord.timestamp(), record.timestamp());
-        assertEquals(TimestampType.CREATE_TIME, record.timestampType());
-        assertEquals(rawRecord.headers(), record.headers());
+        try (final Metrics metrics = new Metrics()) {
+            final RecordDeserializer recordDeserializer = new RecordDeserializer(
+                    new TheSourceNode(
+                            sourceNodeName,
+                            false,
+                            false,
+                            "key",
+                            "value"
+                    ),
+                    null,
+                    new LogContext(),
+                    metrics.sensor("dropped-records")
+            );
+            final ConsumerRecord<Object, Object> record = recordDeserializer.deserialize(null, rawRecord);
+            assertEquals(rawRecord.topic(), record.topic());
+            assertEquals(rawRecord.partition(), record.partition());
+            assertEquals(rawRecord.offset(), record.offset());
+            assertEquals("key", record.key());
+            assertEquals("value", record.value());
+            assertEquals(rawRecord.timestamp(), record.timestamp());
+            assertEquals(TimestampType.CREATE_TIME, record.timestampType());
+            assertEquals(rawRecord.headers(), record.headers());
+        }
     }
 
     @ParameterizedTest
@@ -93,30 +96,35 @@ public class RecordDeserializerTest {
     })
     public void shouldThrowStreamsExceptionWhenDeserializationFailsAndExceptionHandlerRepliesWithFail(final boolean keyThrowsException,
                                                                                                       final boolean valueThrowsException) {
-        final RecordDeserializer recordDeserializer = new RecordDeserializer(
-            new TheSourceNode(
-                sourceNodeName,
-                keyThrowsException,
-                valueThrowsException,
-                "key",
-                "value"
-            ),
-            new DeserializationExceptionHandlerMock(
-                DeserializationExceptionHandler.DeserializationHandlerResponse.FAIL,
-                rawRecord,
-                sourceNodeName,
-                taskId
-            ),
-            new LogContext(),
-            new Metrics().sensor("dropped-records")
-        );
-
-        final StreamsException e = assertThrows(StreamsException.class, () -> recordDeserializer.deserialize(context, rawRecord));
-        assertEquals(e.getMessage(), "Deserialization exception handler is set "
-                + "to fail upon a deserialization error. "
-                + "If you would rather have the streaming pipeline "
-                + "continue after a deserialization error, please set the "
-                + "default.deserialization.exception.handler appropriately.");
+        try (final Metrics metrics = new Metrics()) {
+            final RecordDeserializer recordDeserializer = new RecordDeserializer(
+                    new TheSourceNode(
+                            sourceNodeName,
+                            keyThrowsException,
+                            valueThrowsException,
+                            "key",
+                            "value"
+                    ),
+                    new DeserializationExceptionHandlerMock(
+                            Optional.of(DeserializationHandlerResponse.FAIL),
+                            rawRecord,
+                            sourceNodeName,
+                            taskId
+                    ),
+                    new LogContext(),
+                    metrics.sensor("dropped-records")
+            );
+
+            final StreamsException e = assertThrows(StreamsException.class, () -> recordDeserializer.deserialize(context, rawRecord));
+            assertEquals(
+                    e.getMessage(),
+                    "Deserialization exception handler is set "
+                            + "to fail upon a deserialization error. "
+                            + "If you would rather have the streaming pipeline 
"
+                            + "continue after a deserialization error, please 
set the "
+                            + "default.deserialization.exception.handler 
appropriately."
+            );
+        }
     }
 
     @ParameterizedTest
@@ -127,26 +135,89 @@ public class RecordDeserializerTest {
     })
     public void shouldNotThrowStreamsExceptionWhenDeserializationFailsAndExceptionHandlerRepliesWithContinue(final boolean keyThrowsException,
                                                                                                              final boolean valueThrowsException) {
-        final RecordDeserializer recordDeserializer = new RecordDeserializer(
-            new TheSourceNode(
-                sourceNodeName,
-                keyThrowsException,
-                valueThrowsException,
-                "key",
-                "value"
-            ),
-            new DeserializationExceptionHandlerMock(
-                DeserializationExceptionHandler.DeserializationHandlerResponse.CONTINUE,
-                rawRecord,
-                sourceNodeName,
-                taskId
-            ),
-            new LogContext(),
-            new Metrics().sensor("dropped-records")
-        );
-
-        final ConsumerRecord<Object, Object> record = recordDeserializer.deserialize(context, rawRecord);
-        assertNull(record);
+        try (final Metrics metrics = new Metrics()) {
+            final RecordDeserializer recordDeserializer = new RecordDeserializer(
+                    new TheSourceNode(
+                            sourceNodeName,
+                            keyThrowsException,
+                            valueThrowsException,
+                            "key",
+                            "value"
+                    ),
+                    new DeserializationExceptionHandlerMock(
+                            Optional.of(DeserializationHandlerResponse.CONTINUE),
+                            rawRecord,
+                            sourceNodeName,
+                            taskId
+                    ),
+                    new LogContext(),
+                    metrics.sensor("dropped-records")
+            );
+
+            final ConsumerRecord<Object, Object> record = recordDeserializer.deserialize(context, rawRecord);
+            assertNull(record);
+        }
+    }
+
+    @Test
+    public void shouldFailWhenDeserializationFailsAndExceptionHandlerReturnsNull() {
+        try (final Metrics metrics = new Metrics()) {
+            final RecordDeserializer recordDeserializer = new RecordDeserializer(
+                    new TheSourceNode(
+                            sourceNodeName,
+                            true,
+                            false,
+                            "key",
+                            "value"
+                    ),
+                    new DeserializationExceptionHandlerMock(
+                            Optional.empty(),
+                            rawRecord,
+                            sourceNodeName,
+                            taskId
+                    ),
+                    new LogContext(),
+                    metrics.sensor("dropped-records")
+            );
+
+            final StreamsException exception = assertThrows(
+                    StreamsException.class,
+                    () -> recordDeserializer.deserialize(context, rawRecord)
+            );
+            assertEquals("Fatal user code error in deserialization error 
callback", exception.getMessage());
+            assertInstanceOf(NullPointerException.class, exception.getCause());
+            assertEquals("Invalid DeserializationExceptionHandler response.", 
exception.getCause().getMessage());
+        }
+    }
+
+    @Test
+    public void shouldFailWhenDeserializationFailsAndExceptionHandlerThrows() {
+        try (final Metrics metrics = new Metrics()) {
+            final RecordDeserializer recordDeserializer = new RecordDeserializer(
+                    new TheSourceNode(
+                            sourceNodeName,
+                            true,
+                            false,
+                            "key",
+                            "value"
+                    ),
+                    new DeserializationExceptionHandlerMock(
+                            null, // indicate to throw an exception
+                            rawRecord,
+                            sourceNodeName,
+                            taskId
+                    ),
+                    new LogContext(),
+                    metrics.sensor("dropped-records")
+            );
+
+            final StreamsException exception = assertThrows(
+                    StreamsException.class,
+                    () -> recordDeserializer.deserialize(context, rawRecord)
+            );
+            assertEquals("Fatal user code error in deserialization error 
callback", exception.getMessage());
+            assertEquals("CRASH", exception.getCause().getMessage());
+        }
     }
 
     static class TheSourceNode extends SourceNode<Object, Object> {
@@ -185,12 +256,12 @@ public class RecordDeserializerTest {
     }
 
     public static class DeserializationExceptionHandlerMock implements DeserializationExceptionHandler {
-        private final DeserializationHandlerResponse response;
+        private final Optional<DeserializationHandlerResponse> response;
         private final ConsumerRecord<byte[], byte[]> expectedRecord;
         private final String expectedProcessorNodeId;
         private final TaskId expectedTaskId;
 
-        public DeserializationExceptionHandlerMock(final DeserializationHandlerResponse response,
+        public DeserializationExceptionHandlerMock(final Optional<DeserializationHandlerResponse> response,
                                                    final ConsumerRecord<byte[], byte[]> record,
                                                    final String processorNodeId,
                                                    final TaskId taskId) {
@@ -212,7 +283,10 @@ public class RecordDeserializerTest {
             assertEquals(expectedRecord, record);
             assertInstanceOf(RuntimeException.class, exception);
             assertEquals("KABOOM!", exception.getMessage());
-            return response;
+            if (response == null) {
+                throw new RuntimeException("CRASH");
+            }
+            return response.orElse(null);
         }
 
         @Override
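
The RecordDeserializerTest changes mirror the same contract for deserialization: FAIL and CONTINUE remain the only valid responses, while null or a thrown exception becomes "Fatal user code error in deserialization error callback". For reference, a minimal well-behaved handler sketch (hypothetical class, not part of this commit):

    import java.util.Map;

    import org.apache.kafka.clients.consumer.ConsumerRecord;
    import org.apache.kafka.streams.errors.DeserializationExceptionHandler;
    import org.apache.kafka.streams.errors.ErrorHandlerContext;

    // Hypothetical handler that always returns a valid, non-null response.
    public class SkipCorruptedRecordsHandler implements DeserializationExceptionHandler {
        @Override
        public DeserializationHandlerResponse handle(final ErrorHandlerContext context,
                                                     final ConsumerRecord<byte[], byte[]> record,
                                                     final Exception exception) {
            // CONTINUE drops the record; the tests above verify that the record
            // is counted against the "dropped-records" sensor and null is returned.
            return DeserializationHandlerResponse.CONTINUE;
        }

        @Override
        public void configure(final Map<String, ?> configs) {
            // no configuration needed for this sketch
        }
    }
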
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index a8771c21539..3fa33ef8954 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -53,6 +53,7 @@ import org.apache.kafka.streams.errors.LogAndContinueProcessingExceptionHandler;
 import org.apache.kafka.streams.errors.LogAndFailExceptionHandler;
 import org.apache.kafka.streams.errors.LogAndFailProcessingExceptionHandler;
 import org.apache.kafka.streams.errors.ProcessingExceptionHandler;
+import org.apache.kafka.streams.errors.ProcessingExceptionHandler.ProcessingHandlerResponse;
 import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
@@ -2661,17 +2662,20 @@ public class StreamTaskTest {
     }
 
     @Test
-    public void shouldPunctuateNotHandleFailProcessingExceptionAndThrowStreamsException() {
+    public void punctuateShouldNotHandleFailProcessingExceptionAndThrowStreamsException() {
         when(stateManager.taskId()).thenReturn(taskId);
         when(stateManager.taskType()).thenReturn(TaskType.ACTIVE);
-        task = createStatelessTask(createConfig(AT_LEAST_ONCE, "100",
-            LogAndFailExceptionHandler.class.getName(), LogAndContinueProcessingExceptionHandler.class.getName()));
+        task = createStatelessTask(createConfig(
+            AT_LEAST_ONCE,
+            "100",
+            LogAndFailExceptionHandler.class.getName(),
+            LogAndContinueProcessingExceptionHandler.class.getName()
+        ));
 
-        final StreamsException streamsException = assertThrows(StreamsException.class, () ->
-            task.punctuate(processorStreamTime, 1, PunctuationType.STREAM_TIME, timestamp -> {
-                throw new FailedProcessingException(
-                    new RuntimeException("KABOOM!")
-                );
+        final StreamsException streamsException = assertThrows(
+            StreamsException.class,
+            () -> task.punctuate(processorStreamTime, 1, PunctuationType.STREAM_TIME, timestamp -> {
+                throw new FailedProcessingException(new RuntimeException("KABOOM!"));
             })
         );
 
@@ -2680,11 +2684,15 @@ public class StreamTaskTest {
     }
 
     @Test
-    public void shouldPunctuateNotHandleTaskCorruptedExceptionAndThrowItAsIs() {
+    public void punctuateShouldNotHandleTaskCorruptedExceptionAndThrowItAsIs() {
         when(stateManager.taskId()).thenReturn(taskId);
         when(stateManager.taskType()).thenReturn(TaskType.ACTIVE);
-        task = createStatelessTask(createConfig(AT_LEAST_ONCE, "100",
-            LogAndFailExceptionHandler.class.getName(), LogAndContinueProcessingExceptionHandler.class.getName()));
+        task = createStatelessTask(createConfig(
+            AT_LEAST_ONCE,
+            "100",
+            LogAndFailExceptionHandler.class.getName(),
+            LogAndContinueProcessingExceptionHandler.class.getName()
+        ));
 
         final Set<TaskId> tasksIds = new HashSet<>();
         tasksIds.add(new TaskId(0, 0));
@@ -2695,8 +2703,9 @@ public class StreamTaskTest {
             }
         });
 
-        final TaskCorruptedException taskCorruptedException = assertThrows(TaskCorruptedException.class, () ->
-            task.punctuate(processorStreamTime, 1, PunctuationType.STREAM_TIME, timestamp -> {
+        final TaskCorruptedException taskCorruptedException = assertThrows(
+            TaskCorruptedException.class,
+            () -> task.punctuate(processorStreamTime, 1, PunctuationType.STREAM_TIME, timestamp -> {
                 throw expectedException;
             })
         );
@@ -2705,16 +2714,21 @@ public class StreamTaskTest {
     }
 
     @Test
-    public void shouldPunctuateNotHandleTaskMigratedExceptionAndThrowItAsIs() {
+    public void punctuateShouldNotHandleTaskMigratedExceptionAndThrowItAsIs() {
         when(stateManager.taskId()).thenReturn(taskId);
         when(stateManager.taskType()).thenReturn(TaskType.ACTIVE);
-        task = createStatelessTask(createConfig(AT_LEAST_ONCE, "100",
-            LogAndFailExceptionHandler.class.getName(), LogAndContinueProcessingExceptionHandler.class.getName()));
+        task = createStatelessTask(createConfig(
+            AT_LEAST_ONCE,
+            "100",
+            LogAndFailExceptionHandler.class.getName(),
+            LogAndContinueProcessingExceptionHandler.class.getName()
+        ));
 
         final TaskMigratedException expectedException = new TaskMigratedException("TaskMigratedException", new RuntimeException("Task migrated cause"));
 
-        final TaskMigratedException taskCorruptedException = assertThrows(TaskMigratedException.class, () ->
-            task.punctuate(processorStreamTime, 1, PunctuationType.STREAM_TIME, timestamp -> {
+        final TaskMigratedException taskCorruptedException = assertThrows(
+            TaskMigratedException.class,
+            () -> task.punctuate(processorStreamTime, 1, PunctuationType.STREAM_TIME, timestamp -> {
                 throw expectedException;
             })
         );
@@ -2723,56 +2737,106 @@ public class StreamTaskTest {
     }
 
     @Test
-    public void shouldPunctuateNotThrowStreamsExceptionWhenProcessingExceptionHandlerRepliesWithContinue() {
+    public void punctuateShouldNotThrowStreamsExceptionWhenProcessingExceptionHandlerRepliesWithContinue() {
         when(stateManager.taskId()).thenReturn(taskId);
         when(stateManager.taskType()).thenReturn(TaskType.ACTIVE);
-        task = createStatelessTask(createConfig(AT_LEAST_ONCE, "100",
-            LogAndFailExceptionHandler.class.getName(), LogAndContinueProcessingExceptionHandler.class.getName()));
+        task = createStatelessTask(createConfig(
+            AT_LEAST_ONCE,
+            "100",
+            LogAndFailExceptionHandler.class.getName(),
+            LogAndContinueProcessingExceptionHandler.class.getName()
+        ));
 
-        assertDoesNotThrow(() ->
-            task.punctuate(processorStreamTime, 1, PunctuationType.STREAM_TIME, timestamp -> {
+        task.punctuate(processorStreamTime, 1, PunctuationType.STREAM_TIME, timestamp -> {
+            throw new KafkaException("KABOOM!");
+        });
+    }
+
+    @Test
+    public void punctuateShouldThrowStreamsExceptionWhenProcessingExceptionHandlerRepliesWithFail() {
+        when(stateManager.taskId()).thenReturn(taskId);
+        when(stateManager.taskType()).thenReturn(TaskType.ACTIVE);
+        task = createStatelessTask(createConfig(
+            AT_LEAST_ONCE,
+            "100",
+            LogAndFailExceptionHandler.class.getName(),
+            LogAndFailProcessingExceptionHandler.class.getName()
+        ));
+
+        final StreamsException streamsException = assertThrows(
+            StreamsException.class,
+            () -> task.punctuate(processorStreamTime, 1, PunctuationType.STREAM_TIME, timestamp -> {
                 throw new KafkaException("KABOOM!");
             })
         );
+
+        assertInstanceOf(KafkaException.class, streamsException.getCause());
+        assertEquals("KABOOM!", streamsException.getCause().getMessage());
     }
 
     @Test
-    public void shouldPunctuateThrowStreamsExceptionWhenProcessingExceptionHandlerRepliesWithFail() {
+    public void punctuateShouldThrowStreamsExceptionWhenProcessingExceptionHandlerReturnsNull() {
         when(stateManager.taskId()).thenReturn(taskId);
         when(stateManager.taskType()).thenReturn(TaskType.ACTIVE);
-        task = createStatelessTask(createConfig(AT_LEAST_ONCE, "100",
-            LogAndFailExceptionHandler.class.getName(), LogAndFailProcessingExceptionHandler.class.getName()));
+        task = createStatelessTask(createConfig(
+            AT_LEAST_ONCE,
+            "100",
+            LogAndFailExceptionHandler.class.getName(),
+            NullProcessingExceptionHandler.class.getName()
+        ));
 
-        final StreamsException streamsException = assertThrows(StreamsException.class,
+        final StreamsException streamsException = assertThrows(
+            StreamsException.class,
             () -> task.punctuate(processorStreamTime, 1, PunctuationType.STREAM_TIME, timestamp -> {
                 throw new KafkaException("KABOOM!");
-            }));
+            })
+        );
 
-        assertInstanceOf(KafkaException.class, streamsException.getCause());
-        assertEquals("KABOOM!", streamsException.getCause().getMessage());
+        assertEquals("Fatal user code error in processing error callback", 
streamsException.getMessage());
+        assertInstanceOf(NullPointerException.class, 
streamsException.getCause());
+        assertEquals("Invalid ProcessingExceptionHandler response.", 
streamsException.getCause().getMessage());
     }
 
     @Test
-    public void shouldPunctuateThrowFailedProcessingExceptionWhenProcessingExceptionHandlerThrowsAnException() {
+    public void punctuateShouldThrowFailedProcessingExceptionWhenProcessingExceptionHandlerThrowsAnException() {
         when(stateManager.taskId()).thenReturn(taskId);
         when(stateManager.taskType()).thenReturn(TaskType.ACTIVE);
-        task = createStatelessTask(createConfig(AT_LEAST_ONCE, "100",
-                LogAndFailExceptionHandler.class.getName(), ProcessingExceptionHandlerMock.class.getName()));
+        task = createStatelessTask(createConfig(
+            AT_LEAST_ONCE,
+            "100",
+            LogAndFailExceptionHandler.class.getName(),
+            CrashingProcessingExceptionHandler.class.getName()
+        ));
 
-        final FailedProcessingException streamsException = assertThrows(FailedProcessingException.class,
+        final FailedProcessingException streamsException = assertThrows(
+            FailedProcessingException.class,
             () -> task.punctuate(processorStreamTime, 1, PunctuationType.STREAM_TIME, timestamp -> {
                 throw new KafkaException("KABOOM!");
-            }));
+            })
+        );
 
-        assertInstanceOf(RuntimeException.class, streamsException.getCause());
+        assertEquals("Fatal user code error in processing error callback", 
streamsException.getMessage());
         assertEquals("KABOOM from ProcessingExceptionHandlerMock!", 
streamsException.getCause().getMessage());
     }
 
-    public static class ProcessingExceptionHandlerMock implements ProcessingExceptionHandler {
+    public static class CrashingProcessingExceptionHandler implements ProcessingExceptionHandler {
         @Override
-        public ProcessingExceptionHandler.ProcessingHandlerResponse handle(final ErrorHandlerContext context, final Record<?, ?> record, final Exception exception) {
+        public ProcessingHandlerResponse handle(final ErrorHandlerContext context, final Record<?, ?> record, final Exception exception) {
             throw new RuntimeException("KABOOM from ProcessingExceptionHandlerMock!");
         }
+
+        @Override
+        public void configure(final Map<String, ?> configs) {
+            // No-op
+        }
+    }
+
+    public static class NullProcessingExceptionHandler implements ProcessingExceptionHandler {
+        @Override
+        public ProcessingHandlerResponse handle(final ErrorHandlerContext context, final Record<?, ?> record, final Exception exception) {
+            return null;
+        }
+
         @Override
         public void configure(final Map<String, ?> configs) {
             // No-op

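
The StreamTaskTest changes extend the same rule to processing and punctuation errors: a ProcessingExceptionHandler may reply CONTINUE or FAIL, but returning null (NullProcessingExceptionHandler above) or throwing (CrashingProcessingExceptionHandler above) now surfaces as "Fatal user code error in processing error callback". A well-behaved counterpart might look like this sketch (hypothetical class, not part of this commit):

    import java.util.Map;

    import org.apache.kafka.streams.errors.ErrorHandlerContext;
    import org.apache.kafka.streams.errors.ProcessingExceptionHandler;
    import org.apache.kafka.streams.processor.api.Record;

    // Hypothetical handler returning a valid response for every invocation.
    public class LenientProcessingExceptionHandler implements ProcessingExceptionHandler {
        @Override
        public ProcessingHandlerResponse handle(final ErrorHandlerContext context,
                                                final Record<?, ?> record,
                                                final Exception exception) {
            // For failures raised inside a punctuator the record argument may be
            // null, since no input record is being processed at that point.
            return ProcessingHandlerResponse.CONTINUE;
        }

        @Override
        public void configure(final Map<String, ?> configs) {
            // no-op
        }
    }
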