This is an automated email from the ASF dual-hosted git repository.

lucasbru pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 34e426b5178 KAFKA-19271: Add internal ConsumerWrapper (#19697)
34e426b5178 is described below

commit 34e426b5178a5376668df6a2c6cf130fce78d431
Author: Matthias J. Sax <[email protected]>
AuthorDate: Fri May 16 02:57:37 2025 -0700

    KAFKA-19271: Add internal ConsumerWrapper (#19697)

    With KIP-1071 enabled, the main consumer is created differently,
    sidestepping `KafkaClientSupplier`.

    To allow injecting test wrappers, we add an internal ConsumerWrapper
    until we define a new public interface.
    
    Reviewers: Lucas Brutschy <[email protected]>
---
 .../KafkaStreamsTelemetryIntegrationTest.java      | 148 +++++++---
 .../org/apache/kafka/streams/StreamsConfig.java    |   3 +
 .../kafka/streams/internals/ConsumerWrapper.java   | 316 +++++++++++++++++++++
 .../streams/processor/internals/StreamThread.java  |  45 ++-
 .../processor/internals/StreamThreadTest.java      |  33 +++
 5 files changed, 508 insertions(+), 37 deletions(-)
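
Below, a minimal sketch of how a test can inject a wrapper through the new
internal config. The CountingWrapper class, topic names, and bootstrap
address are illustrative placeholders; only ConsumerWrapper,
StreamsConfig.GROUP_PROTOCOL_CONFIG, and
StreamsConfig.InternalConfig.INTERNAL_CONSUMER_WRAPPER come from this patch.

    import java.time.Duration;
    import java.util.Properties;

    import org.apache.kafka.clients.consumer.ConsumerRecords;
    import org.apache.kafka.streams.KafkaStreams;
    import org.apache.kafka.streams.StreamsBuilder;
    import org.apache.kafka.streams.StreamsConfig;
    import org.apache.kafka.streams.internals.ConsumerWrapper;

    public class ConsumerWrapperExample {

        // Needs a public no-arg constructor: StreamThread instantiates the
        // wrapper reflectively (Utils.newInstance) and then calls
        // wrapConsumer to hand over the AsyncKafkaConsumer delegate.
        public static class CountingWrapper extends ConsumerWrapper {
            @Override
            public ConsumerRecords<byte[], byte[]> poll(final Duration timeout) {
                // Every other Consumer method still delegates unchanged.
                final ConsumerRecords<byte[], byte[]> records = super.poll(timeout);
                System.out.println("polled " + records.count() + " records");
                return records;
            }
        }

        public static void main(final String[] args) {
            final Properties props = new Properties();
            props.put(StreamsConfig.APPLICATION_ID_CONFIG, "wrapper-demo");
            props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
            // The wrapper only takes effect on the KIP-1071 ("streams"
            // group protocol) code path, which bypasses KafkaClientSupplier.
            props.put(StreamsConfig.GROUP_PROTOCOL_CONFIG, "streams");
            // Private testing API; accepts a Class or a class name.
            props.put(StreamsConfig.InternalConfig.INTERNAL_CONSUMER_WRAPPER,
                CountingWrapper.class);

            final StreamsBuilder builder = new StreamsBuilder();
            builder.stream("input-topic").to("output-topic");

            try (final KafkaStreams streams =
                     new KafkaStreams(builder.build(), props)) {
                streams.start();
            }
        }
    }

Per the dispatch in maybeWrapConsumer below, the same property also accepts
the class name as a String, e.g. CountingWrapper.class.getName().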

diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/KafkaStreamsTelemetryIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/KafkaStreamsTelemetryIntegrationTest.java
index 0af02321a67..eefb4e3e287 100644
--- a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/KafkaStreamsTelemetryIntegrationTest.java
+++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/KafkaStreamsTelemetryIntegrationTest.java
@@ -22,6 +22,8 @@ import org.apache.kafka.clients.admin.AdminClientConfig;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
+import org.apache.kafka.clients.consumer.internals.AsyncKafkaConsumer;
+import org.apache.kafka.clients.consumer.internals.StreamsRebalanceData;
 import org.apache.kafka.clients.producer.KafkaProducer;
 import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.common.Metric;
@@ -46,6 +48,7 @@ import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.Topology;
 import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.internals.ConsumerWrapper;
 import org.apache.kafka.streams.kstream.Consumed;
 import org.apache.kafka.streams.kstream.Produced;
 import org.apache.kafka.streams.processor.api.Processor;
@@ -62,7 +65,6 @@ import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Tag;
-import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInfo;
 import org.junit.jupiter.api.Timeout;
 import org.junit.jupiter.params.ParameterizedTest;
@@ -110,13 +112,23 @@ public class KafkaStreamsTelemetryIntegrationTest {
     private  KeyValueIterator<String, String> globalStoreIterator;
 
     private static EmbeddedKafkaCluster cluster;
-    private static final List<TestingMetricsInterceptingConsumer<byte[], byte[]>> INTERCEPTING_CONSUMERS = new ArrayList<>();
+    private static final List<TestingMetricsInterceptor> INTERCEPTING_CONSUMERS = new ArrayList<>();
     private static final List<TestingMetricsInterceptingAdminClient> INTERCEPTING_ADMIN_CLIENTS = new ArrayList<>();
     private static final int NUM_BROKERS = 3;
     private static final int FIRST_INSTANCE_CLIENT = 0;
     private static final int SECOND_INSTANCE_CLIENT = 1;
     private static final Logger LOG = LoggerFactory.getLogger(KafkaStreamsTelemetryIntegrationTest.class);
 
+    static Stream<Arguments> recordingLevelParameters() {
+        return Stream.of(
+            Arguments.of("INFO", "classic"),
+            Arguments.of("DEBUG", "classic"),
+            Arguments.of("TRACE", "classic"),
+            Arguments.of("INFO", "streams"),
+            Arguments.of("DEBUG", "streams"),
+            Arguments.of("TRACE", "streams")
+        );
+    }
 
     @BeforeAll
     public static void startCluster() throws IOException {
@@ -160,9 +172,9 @@ public class KafkaStreamsTelemetryIntegrationTest {
     }
 
     @ParameterizedTest
-    @ValueSource(strings = {"INFO", "DEBUG", "TRACE"})
-    public void shouldPushGlobalThreadMetricsToBroker(final String recordingLevel) throws Exception {
-        streamsApplicationProperties = props(true);
+    @MethodSource("recordingLevelParameters")
+    public void shouldPushGlobalThreadMetricsToBroker(final String recordingLevel, final String groupProtocol) throws Exception {
+        streamsApplicationProperties = props(true, groupProtocol);
         streamsApplicationProperties.put(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, recordingLevel);
         final Topology topology = simpleTopology(true);
         subscribeForStreamsMetrics();
@@ -198,10 +210,10 @@ public class KafkaStreamsTelemetryIntegrationTest {
     }
 
     @ParameterizedTest
-    @ValueSource(strings = {"INFO", "DEBUG", "TRACE"})
-    public void shouldPushMetricsToBroker(final String recordingLevel) throws Exception {
+    @MethodSource("recordingLevelParameters")
+    public void shouldPushMetricsToBroker(final String recordingLevel, final String groupProtocol) throws Exception {
         // End-to-end test validating metrics pushed to broker
-        streamsApplicationProperties  = props(true);
+        streamsApplicationProperties  = props(true, groupProtocol);
         streamsApplicationProperties.put(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, recordingLevel);
         final Topology topology = simpleTopology(false);
         subscribeForStreamsMetrics();
@@ -263,9 +275,9 @@ public class KafkaStreamsTelemetryIntegrationTest {
 
     @ParameterizedTest
     @MethodSource("singleAndMultiTaskParameters")
-    public void shouldPassMetrics(final String topologyType, final boolean stateUpdaterEnabled) throws Exception {
+    public void shouldPassMetrics(final String topologyType, final boolean stateUpdaterEnabled, final String groupProtocol) throws Exception {
         // Streams metrics should get passed to Admin and Consumer
-        streamsApplicationProperties = props(stateUpdaterEnabled);
+        streamsApplicationProperties = props(stateUpdaterEnabled, groupProtocol);
         final Topology topology = topologyType.equals("simple") ? simpleTopology(false) : complexTopology();

         try (final KafkaStreams streams = new KafkaStreams(topology, streamsApplicationProperties)) {
@@ -279,7 +291,7 @@ public class KafkaStreamsTelemetryIntegrationTest {
 
 
 
-            final List<MetricName> consumerPassedStreamThreadMetricNames = INTERCEPTING_CONSUMERS.get(FIRST_INSTANCE_CLIENT).passedMetrics.stream().map(KafkaMetric::metricName).toList();
+            final List<MetricName> consumerPassedStreamThreadMetricNames = INTERCEPTING_CONSUMERS.get(FIRST_INSTANCE_CLIENT).passedMetrics().stream().map(KafkaMetric::metricName).toList();
             final List<MetricName> adminPassedStreamClientMetricNames = INTERCEPTING_ADMIN_CLIENTS.get(FIRST_INSTANCE_CLIENT).passedMetrics.stream().map(KafkaMetric::metricName).toList();
 
 
@@ -293,14 +305,14 @@ public class KafkaStreamsTelemetryIntegrationTest {
 
     @ParameterizedTest
     @MethodSource("multiTaskParameters")
-    public void shouldPassCorrectMetricsDynamicInstances(final boolean stateUpdaterEnabled) throws Exception {
+    public void shouldPassCorrectMetricsDynamicInstances(final boolean stateUpdaterEnabled, final String groupProtocol) throws Exception {
         // Correct streams metrics should get passed with dynamic membership
-        streamsApplicationProperties = props(stateUpdaterEnabled);
+        streamsApplicationProperties = props(stateUpdaterEnabled, groupProtocol);
         streamsApplicationProperties.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(appId).getPath() + "-ks1");
         streamsApplicationProperties.put(StreamsConfig.CLIENT_ID_CONFIG, appId + "-ks1");


-        streamsSecondApplicationProperties = props(stateUpdaterEnabled);
+        streamsSecondApplicationProperties = props(stateUpdaterEnabled, groupProtocol);
         streamsSecondApplicationProperties.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(appId).getPath() + "-ks2");
         streamsSecondApplicationProperties.put(StreamsConfig.CLIENT_ID_CONFIG, appId + "-ks2");
         
@@ -312,7 +324,7 @@ public class KafkaStreamsTelemetryIntegrationTest {
             final List<MetricName> streamsTaskMetricNames = streamsOne.metrics().values().stream().map(Metric::metricName)
                     .filter(metricName -> metricName.tags().containsKey("task-id")).toList();

-            final List<MetricName> consumerPassedStreamTaskMetricNames = INTERCEPTING_CONSUMERS.get(FIRST_INSTANCE_CLIENT).passedMetrics.stream().map(KafkaMetric::metricName)
+            final List<MetricName> consumerPassedStreamTaskMetricNames = INTERCEPTING_CONSUMERS.get(FIRST_INSTANCE_CLIENT).passedMetrics().stream().map(KafkaMetric::metricName)
                     .filter(metricName -> metricName.tags().containsKey("task-id")).toList();
 
             /*
@@ -349,9 +361,9 @@ public class KafkaStreamsTelemetryIntegrationTest {
                         .filter(metricName -> metricName.group().equals("stream-state-metrics")).toList();

                 final List<MetricName> consumerOnePassedTaskMetrics = INTERCEPTING_CONSUMERS.get(FIRST_INSTANCE_CLIENT)
-                        .passedMetrics.stream().map(KafkaMetric::metricName).filter(metricName -> metricName.tags().containsKey("task-id")).toList();
+                        .passedMetrics().stream().map(KafkaMetric::metricName).filter(metricName -> metricName.tags().containsKey("task-id")).toList();
                 final List<MetricName> consumerOnePassedStateMetrics = INTERCEPTING_CONSUMERS.get(FIRST_INSTANCE_CLIENT)
-                        .passedMetrics.stream().map(KafkaMetric::metricName).filter(metricName -> metricName.group().equals("stream-state-metrics")).toList();
+                        .passedMetrics().stream().map(KafkaMetric::metricName).filter(metricName -> metricName.group().equals("stream-state-metrics")).toList();

                 final List<MetricName> streamsTwoTaskMetrics = streamsTwo.metrics().values().stream().map(Metric::metricName)
                         .filter(metricName -> metricName.tags().containsKey("task-id")).toList();
@@ -359,9 +371,9 @@ public class KafkaStreamsTelemetryIntegrationTest {
                         .filter(metricName -> metricName.group().equals("stream-state-metrics")).toList();

                 final List<MetricName> consumerTwoPassedTaskMetrics = INTERCEPTING_CONSUMERS.get(SECOND_INSTANCE_CLIENT)
-                        .passedMetrics.stream().map(KafkaMetric::metricName).filter(metricName -> metricName.tags().containsKey("task-id")).toList();
+                        .passedMetrics().stream().map(KafkaMetric::metricName).filter(metricName -> metricName.tags().containsKey("task-id")).toList();
                 final List<MetricName> consumerTwoPassedStateMetrics = INTERCEPTING_CONSUMERS.get(SECOND_INSTANCE_CLIENT)
-                        .passedMetrics.stream().map(KafkaMetric::metricName).filter(metricName -> metricName.group().equals("stream-state-metrics")).toList();
+                        .passedMetrics().stream().map(KafkaMetric::metricName).filter(metricName -> metricName.group().equals("stream-state-metrics")).toList();
                 /*
                  Confirm pre-existing KafkaStreams instance one only passes metrics for its tasks and has no metrics for previous tasks
                  */
@@ -391,10 +403,11 @@ public class KafkaStreamsTelemetryIntegrationTest {
         }
     }
 
-    @Test
-    public void passedMetricsShouldNotLeakIntoClientMetrics() throws Exception {
+    @ParameterizedTest
+    @ValueSource(strings = {"classic", "streams"})
+    public void passedMetricsShouldNotLeakIntoClientMetrics(final String groupProtocol) throws Exception {
         // Streams metrics should not be visible in client metrics
-        streamsApplicationProperties = props(true);
+        streamsApplicationProperties = props(true, groupProtocol);
         final Topology topology =  complexTopology();

         try (final KafkaStreams streams = new KafkaStreams(topology, streamsApplicationProperties)) {
@@ -423,6 +436,7 @@ public class KafkaStreamsTelemetryIntegrationTest {
             clientMetricsService.alterClientMetrics(commandOptions);
         }
     }
+
     private List<String> getTaskIdsAsStrings(final KafkaStreams streams) {
         return streams.metadataForLocalThreads().stream()
                 .flatMap(threadMeta -> threadMeta.activeTasks().stream()
@@ -431,19 +445,32 @@ public class KafkaStreamsTelemetryIntegrationTest {
     }
 
     private static Stream<Arguments> singleAndMultiTaskParameters() {
-        return Stream.of(Arguments.of("simple", true),
-                Arguments.of("simple", false),
-                Arguments.of("complex", true),
-                Arguments.of("complex", false));
+        return Stream.of(
+            Arguments.of("simple", true, "classic"),
+            Arguments.of("simple", false, "classic"),
+            Arguments.of("complex", true, "classic"),
+            Arguments.of("complex", false, "classic"),
+            Arguments.of("simple", true, "streams"),
+            Arguments.of("simple", false, "streams"),
+            Arguments.of("complex", true, "streams"),
+            Arguments.of("complex", false, "streams")
+        );
     }
 
     private static Stream<Arguments> multiTaskParameters() {
-        return Stream.of(Arguments.of(true),
-                Arguments.of(false));
+        return Stream.of(
+            Arguments.of(true, "classic"),
+            Arguments.of(false, "classic"),
+            Arguments.of(true, "streams"),
+            Arguments.of(false, "streams")
+        );
     }
 
-    private Properties props(final boolean stateUpdaterEnabled) {
-        return props(mkObjectProperties(mkMap(mkEntry(StreamsConfig.InternalConfig.STATE_UPDATER_ENABLED, stateUpdaterEnabled))));
+    private Properties props(final boolean stateUpdaterEnabled, final String groupProtocol) {
+        return props(mkObjectProperties(mkMap(
+            mkEntry(StreamsConfig.InternalConfig.STATE_UPDATER_ENABLED, stateUpdaterEnabled),
+            mkEntry(StreamsConfig.GROUP_PROTOCOL_CONFIG, groupProtocol)
+        )));
     }
 
     private Properties props(final Properties extraProperties) {
@@ -455,6 +482,7 @@ public class KafkaStreamsTelemetryIntegrationTest {
         streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.StringSerde.class);
         streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class);
         streamsConfiguration.put(StreamsConfig.DEFAULT_CLIENT_SUPPLIER_CONFIG, TestClientSupplier.class);
+        streamsConfiguration.put(StreamsConfig.InternalConfig.INTERNAL_CONSUMER_WRAPPER, TestConsumerWrapper.class);
         streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
         streamsConfiguration.putAll(extraProperties);
         return streamsConfiguration;
@@ -545,9 +573,23 @@ public class KafkaStreamsTelemetryIntegrationTest {
         }
     }
 
-    public static class TestingMetricsInterceptingConsumer<K, V> extends KafkaConsumer<K, V> {
+    public static class TestConsumerWrapper extends ConsumerWrapper {
+        @Override
+        public void wrapConsumer(final AsyncKafkaConsumer<byte[], byte[]> delegate, final Map<String, Object> config, final Optional<StreamsRebalanceData> streamsRebalanceData) {
+            final TestingMetricsInterceptingAsyncConsumer<byte[], byte[]> consumer = new TestingMetricsInterceptingAsyncConsumer<>(config, new ByteArrayDeserializer(), new ByteArrayDeserializer(), streamsRebalanceData);
+            INTERCEPTING_CONSUMERS.add(consumer);
+
+            super.wrapConsumer(consumer, config, streamsRebalanceData);
+        }
+    }
+
+    public interface TestingMetricsInterceptor {
+        List<KafkaMetric> passedMetrics();
+        Map<MetricName, ? extends Metric> metrics();
+    }
 
-        public List<KafkaMetric> passedMetrics = new ArrayList<>();
+    public static class TestingMetricsInterceptingConsumer<K, V> extends KafkaConsumer<K, V> implements TestingMetricsInterceptor {
+        private final List<KafkaMetric> passedMetrics = new ArrayList<>();

         public TestingMetricsInterceptingConsumer(final Map<String, Object> configs, final Deserializer<K> keyDeserializer, final Deserializer<V> valueDeserializer) {
             super(configs, keyDeserializer, valueDeserializer);
@@ -564,6 +606,48 @@ public class KafkaStreamsTelemetryIntegrationTest {
             passedMetrics.remove(metric);
             super.unregisterMetricFromSubscription(metric);
         }
+
+        @Override
+        public List<KafkaMetric> passedMetrics() {
+            return passedMetrics;
+        }
+    }
+
+    public static class TestingMetricsInterceptingAsyncConsumer<K, V> extends AsyncKafkaConsumer<K, V> implements TestingMetricsInterceptor {
+        private final List<KafkaMetric> passedMetrics = new ArrayList<>();
+
+        public TestingMetricsInterceptingAsyncConsumer(
+            final Map<String, Object> configs,
+            final Deserializer<K> keyDeserializer,
+            final Deserializer<V> valueDeserializer,
+            final Optional<StreamsRebalanceData> streamsRebalanceData
+        ) {
+            super(
+                new ConsumerConfig(
+                    ConsumerConfig.appendDeserializerToConfig(configs, keyDeserializer, valueDeserializer)
+                ),
+                keyDeserializer,
+                valueDeserializer,
+                streamsRebalanceData
+            );
+        }
+
+        @Override
+        public void registerMetricForSubscription(final KafkaMetric metric) {
+            passedMetrics.add(metric);
+            super.registerMetricForSubscription(metric);
+        }
+
+        @Override
+        public void unregisterMetricFromSubscription(final KafkaMetric metric) {
+            passedMetrics.remove(metric);
+            super.unregisterMetricFromSubscription(metric);
+        }
+
+        @Override
+        public List<KafkaMetric> passedMetrics() {
+            return passedMetrics;
+        }
     }
 
     public static class TelemetryPlugin implements ClientTelemetry, MetricsReporter, ClientTelemetryReceiver {
diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
index e19cb03b207..c8fa251b6be 100644
--- a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
+++ b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
@@ -1320,6 +1320,9 @@ public class StreamsConfig extends AbstractConfig {
         // This is settable in the main Streams config, but it's a private API for testing
         public static final String ASSIGNMENT_LISTENER = "__assignment.listener__";

+        // This is settable in the main Streams config, but it's a private API for testing
+        public static final String INTERNAL_CONSUMER_WRAPPER = "__internal.consumer.wrapper__";
+
         // Private API used to control the emit latency for left/outer join results (https://issues.apache.org/jira/browse/KAFKA-10847)
         public static final String EMIT_INTERVAL_MS_KSTREAMS_OUTER_JOIN_SPURIOUS_RESULTS_FIX = "__emit.interval.ms.kstreams.outer.join.spurious.results.fix__";
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/internals/ConsumerWrapper.java b/streams/src/main/java/org/apache/kafka/streams/internals/ConsumerWrapper.java
new file mode 100644
index 00000000000..20cd7cb84d7
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/internals/ConsumerWrapper.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.internals;
+
+import org.apache.kafka.clients.consumer.CloseOptions;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerGroupMetadata;
+import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
+import org.apache.kafka.clients.consumer.OffsetCommitCallback;
+import org.apache.kafka.clients.consumer.SubscriptionPattern;
+import org.apache.kafka.clients.consumer.internals.AsyncKafkaConsumer;
+import org.apache.kafka.clients.consumer.internals.StreamsRebalanceData;
+import org.apache.kafka.common.Metric;
+import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.metrics.KafkaMetric;
+
+import java.time.Duration;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.OptionalLong;
+import java.util.Set;
+import java.util.regex.Pattern;
+
+public abstract class ConsumerWrapper implements Consumer<byte[], byte[]> {
+    protected AsyncKafkaConsumer<byte[], byte[]> delegate;
+
+    public void wrapConsumer(
+        final AsyncKafkaConsumer<byte[], byte[]> delegate,
+        final Map<String, Object> config,
+        final Optional<StreamsRebalanceData> streamsRebalanceData
+    ) {
+        this.delegate = delegate;
+    }
+
+    public AsyncKafkaConsumer<byte[], byte[]> consumer() {
+        return delegate;
+    }
+
+    @Override
+    public Set<TopicPartition> assignment() {
+        return delegate.assignment();
+    }
+
+    @Override
+    public Set<String> subscription() {
+        return delegate.subscription();
+    }
+
+    @Override
+    public void subscribe(final Collection<String> topics) {
+        delegate.subscribe(topics);
+    }
+
+    @Override
+    public void subscribe(final Collection<String> topics, final ConsumerRebalanceListener callback) {
+        delegate.subscribe(topics, callback);
+    }
+
+    @Override
+    public void assign(final Collection<TopicPartition> partitions) {
+        delegate.assign(partitions);
+    }
+
+    @Override
+    public void subscribe(final Pattern pattern, final ConsumerRebalanceListener callback) {
+        delegate.subscribe(pattern, callback);
+    }
+
+    @Override
+    public void subscribe(final Pattern pattern) {
+        delegate.subscribe(pattern);
+    }
+
+    @Override
+    public void subscribe(final SubscriptionPattern pattern, final ConsumerRebalanceListener callback) {
+        delegate.subscribe(pattern, callback);
+    }
+
+    @Override
+    public void subscribe(final SubscriptionPattern pattern) {
+        delegate.subscribe(pattern);
+    }
+
+    @Override
+    public void unsubscribe() {
+        delegate.unsubscribe();
+    }
+
+    @Override
+    public ConsumerRecords<byte[], byte[]> poll(final Duration timeout) {
+        return delegate.poll(timeout);
+    }
+
+    @Override
+    public void commitSync() {
+        delegate.commitSync();
+    }
+
+    @Override
+    public void commitSync(final Duration timeout) {
+        delegate.commitSync(timeout);
+    }
+
+    @Override
+    public void commitSync(final Map<TopicPartition, OffsetAndMetadata> offsets) {
+        delegate.commitSync(offsets);
+    }
+
+    @Override
+    public void commitSync(final Map<TopicPartition, OffsetAndMetadata> offsets, final Duration timeout) {
+        delegate.commitSync(offsets, timeout);
+    }
+
+    @Override
+    public void commitAsync() {
+        delegate.commitAsync();
+    }
+
+    @Override
+    public void commitAsync(final OffsetCommitCallback callback) {
+        delegate.commitAsync(callback);
+    }
+
+    @Override
+    public void commitAsync(final Map<TopicPartition, OffsetAndMetadata> offsets, final OffsetCommitCallback callback) {
+        delegate.commitAsync(offsets, callback);
+    }
+
+    @Override
+    public void registerMetricForSubscription(final KafkaMetric metric) {
+        delegate.registerMetricForSubscription(metric);
+    }
+
+    @Override
+    public void unregisterMetricFromSubscription(final KafkaMetric metric) {
+        delegate.unregisterMetricFromSubscription(metric);
+    }
+
+    @Override
+    public void seek(final TopicPartition partition, final long offset) {
+        delegate.seek(partition, offset);
+    }
+
+    @Override
+    public void seek(final TopicPartition partition, final OffsetAndMetadata offsetAndMetadata) {
+        delegate.seek(partition, offsetAndMetadata);
+    }
+
+    @Override
+    public void seekToBeginning(final Collection<TopicPartition> partitions) {
+        delegate.seekToBeginning(partitions);
+    }
+
+    @Override
+    public void seekToEnd(final Collection<TopicPartition> partitions) {
+        delegate.seekToEnd(partitions);
+    }
+
+    @Override
+    public long position(final TopicPartition partition) {
+        return delegate.position(partition);
+    }
+
+    @Override
+    public long position(final TopicPartition partition, final Duration timeout) {
+        return delegate.position(partition, timeout);
+    }
+
+    @Override
+    public Map<TopicPartition, OffsetAndMetadata> committed(final Set<TopicPartition> partitions) {
+        return delegate.committed(partitions);
+    }
+
+    @Override
+    public Map<TopicPartition, OffsetAndMetadata> committed(final Set<TopicPartition> partitions, final Duration timeout) {
+        return delegate.committed(partitions, timeout);
+    }
+
+    @Override
+    public Uuid clientInstanceId(final Duration timeout) {
+        return delegate.clientInstanceId(timeout);
+    }
+
+    @Override
+    public Map<MetricName, ? extends Metric> metrics() {
+        return delegate.metrics();
+    }
+
+    @Override
+    public List<PartitionInfo> partitionsFor(final String topic) {
+        return delegate.partitionsFor(topic);
+    }
+
+    @Override
+    public List<PartitionInfo> partitionsFor(final String topic, final Duration timeout) {
+        return delegate.partitionsFor(topic, timeout);
+    }
+
+    @Override
+    public Map<String, List<PartitionInfo>> listTopics() {
+        return delegate.listTopics();
+    }
+
+    @Override
+    public Map<String, List<PartitionInfo>> listTopics(final Duration timeout) {
+        return delegate.listTopics(timeout);
+    }
+
+    @Override
+    public Set<TopicPartition> paused() {
+        return delegate.paused();
+    }
+
+    @Override
+    public void pause(final Collection<TopicPartition> partitions) {
+        delegate.pause(partitions);
+    }
+
+    @Override
+    public void resume(final Collection<TopicPartition> partitions) {
+        delegate.resume(partitions);
+    }
+
+    @Override
+    public Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(final Map<TopicPartition, Long> timestampsToSearch) {
+        return delegate.offsetsForTimes(timestampsToSearch);
+    }
+
+    @Override
+    public Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(final Map<TopicPartition, Long> timestampsToSearch, final Duration timeout) {
+        return delegate.offsetsForTimes(timestampsToSearch, timeout);
+    }
+
+    @Override
+    public Map<TopicPartition, Long> beginningOffsets(final Collection<TopicPartition> partitions) {
+        return delegate.beginningOffsets(partitions);
+    }
+
+    @Override
+    public Map<TopicPartition, Long> beginningOffsets(final Collection<TopicPartition> partitions, final Duration timeout) {
+        return delegate.beginningOffsets(partitions, timeout);
+    }
+
+    @Override
+    public Map<TopicPartition, Long> endOffsets(final Collection<TopicPartition> partitions) {
+        return delegate.endOffsets(partitions);
+    }
+
+    @Override
+    public Map<TopicPartition, Long> endOffsets(final Collection<TopicPartition> partitions, final Duration timeout) {
+        return delegate.endOffsets(partitions, timeout);
+    }
+
+    @Override
+    public OptionalLong currentLag(final TopicPartition topicPartition) {
+        return delegate.currentLag(topicPartition);
+    }
+
+    @Override
+    public ConsumerGroupMetadata groupMetadata() {
+        return delegate.groupMetadata();
+    }
+
+    @Override
+    public void enforceRebalance() {
+        delegate.enforceRebalance();
+    }
+
+    @Override
+    public void enforceRebalance(final String reason) {
+        delegate.enforceRebalance(reason);
+    }
+
+    @Override
+    public void close() {
+        delegate.close();
+    }
+
+    @Deprecated
+    @Override
+    public void close(final Duration timeout) {
+        delegate.close(timeout);
+    }
+
+    @Override
+    public void close(final CloseOptions option) {
+        delegate.close(option);
+    }
+
+    @Override
+    public void wakeup() {
+        delegate.wakeup();
+    }
+}
\ No newline at end of file
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 1944f25f9df..fdc5e8df4bc 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -55,6 +55,7 @@ import org.apache.kafka.streams.ThreadMetadata;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
+import org.apache.kafka.streams.internals.ConsumerWrapper;
 import org.apache.kafka.streams.internals.metrics.ClientMetrics;
 import org.apache.kafka.streams.internals.metrics.StreamsThreadMetricsDelegatingReporter;
 import org.apache.kafka.streams.processor.StandbyUpdateListener;
@@ -550,11 +551,16 @@ public class StreamThread extends Thread implements ProcessingThread {
             );
             final ByteArrayDeserializer keyDeserializer = new ByteArrayDeserializer();
             final ByteArrayDeserializer valueDeserializer = new ByteArrayDeserializer();
+
             return new MainConsumerSetup(
-                new AsyncKafkaConsumer<>(
-                    new ConsumerConfig(ConsumerConfig.appendDeserializerToConfig(consumerConfigs, keyDeserializer, valueDeserializer)),
-                    keyDeserializer,
-                    valueDeserializer,
+                maybeWrapConsumer(
+                    consumerConfigs,
+                    new AsyncKafkaConsumer<>(
+                        new ConsumerConfig(ConsumerConfig.appendDeserializerToConfig(consumerConfigs, keyDeserializer, valueDeserializer)),
+                        keyDeserializer,
+                        valueDeserializer,
+                        streamsRebalanceData
+                    ),
                     streamsRebalanceData
                 ),
                 streamsRebalanceData
@@ -567,6 +573,32 @@ public class StreamThread extends Thread implements ProcessingThread {
         }
     }
 
+    private static Consumer<byte[], byte[]> maybeWrapConsumer(final Map<String, Object> config,
+                                                              final AsyncKafkaConsumer<byte[], byte[]> consumer,
+                                                              final Optional<StreamsRebalanceData> streamsRebalanceData) {
+        final Object o = config.get(InternalConfig.INTERNAL_CONSUMER_WRAPPER);
+        if (o == null) {
+            return consumer;
+        }
+
+        final ConsumerWrapper wrapper;
+        if (o instanceof String) {
+            try {
+                wrapper = Utils.newInstance((String) o, ConsumerWrapper.class);
+            } catch (final ClassNotFoundException e) {
+                throw new IllegalArgumentException(e);
+            }
+        } else if (o instanceof Class<?>) {
+            wrapper = (ConsumerWrapper) Utils.newInstance((Class<?>) o);
+        } else {
+            throw new IllegalArgumentException("Internal config " + 
InternalConfig.INTERNAL_CONSUMER_WRAPPER + " must be a class or class name");
+        }
+
+        wrapper.wrapConsumer(consumer, config, streamsRebalanceData);
+
+        return wrapper;
+    }
+
     private static class MainConsumerSetup {
         public final Consumer<byte[], byte[]> mainConsumer;
         public final Optional<StreamsRebalanceData> streamsRebalanceData;
@@ -1105,7 +1137,10 @@ public class StreamThread extends Thread implements ProcessingThread {
             mainConsumer.subscribe(topologyMetadata.sourceTopicPattern(), rebalanceListener);
         } else {
             if (streamsRebalanceData.isPresent()) {
-                ((AsyncKafkaConsumer<byte[], byte[]>) mainConsumer).subscribe(
+                final AsyncKafkaConsumer<byte[], byte[]> consumer = mainConsumer instanceof ConsumerWrapper
+                    ? ((ConsumerWrapper) mainConsumer).consumer()
+                    : (AsyncKafkaConsumer<byte[], byte[]>) mainConsumer;
+                consumer.subscribe(
                     topologyMetadata.allFullSourceTopicNames(),
                     new DefaultStreamsRebalanceListener(
                         log,
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 96090aa32fa..54230d11d3b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -69,6 +69,7 @@ import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
+import org.apache.kafka.streams.internals.ConsumerWrapper;
 import org.apache.kafka.streams.kstream.Consumed;
 import org.apache.kafka.streams.kstream.Materialized;
 import org.apache.kafka.streams.kstream.internals.ConsumedInternal;
@@ -3925,6 +3926,38 @@ public class StreamThreadTest {
         );
     }
 
+    @ParameterizedTest
+    @MethodSource("data")
+    public void shouldWrapMainConsumerFromClassConfig(final boolean stateUpdaterEnabled, final boolean processingThreadsEnabled) {
+        final Properties streamsConfigProps = configProps(false, stateUpdaterEnabled, processingThreadsEnabled);
+        streamsConfigProps.put(StreamsConfig.GROUP_PROTOCOL_CONFIG, "streams");
+        streamsConfigProps.put(InternalConfig.INTERNAL_CONSUMER_WRAPPER, TestWrapper.class);
+
+        thread = createStreamThread("clientId", new StreamsConfig(streamsConfigProps));
+
+        assertInstanceOf(
+            AsyncKafkaConsumer.class,
+            assertInstanceOf(TestWrapper.class, thread.mainConsumer()).consumer()
+        );
+    }
+
+    @ParameterizedTest
+    @MethodSource("data")
+    public void shouldWrapMainConsumerFromStringConfig(final boolean stateUpdaterEnabled, final boolean processingThreadsEnabled) {
+        final Properties streamsConfigProps = configProps(false, stateUpdaterEnabled, processingThreadsEnabled);
+        streamsConfigProps.put(StreamsConfig.GROUP_PROTOCOL_CONFIG, "streams");
+        streamsConfigProps.put(InternalConfig.INTERNAL_CONSUMER_WRAPPER, TestWrapper.class.getName());
+
+        thread = createStreamThread("clientId", new StreamsConfig(streamsConfigProps));
+
+        assertInstanceOf(
+            AsyncKafkaConsumer.class,
+            assertInstanceOf(TestWrapper.class, thread.mainConsumer()).consumer()
+        );
+    }
+
+    public static final class TestWrapper extends ConsumerWrapper { }
+
     private StreamThread setUpThread(final Properties streamsConfigProps) {
         final StreamsConfig config = new StreamsConfig(streamsConfigProps);
         final ConsumerGroupMetadata consumerGroupMetadata = Mockito.mock(ConsumerGroupMetadata.class);

