This is an automated email from the ASF dual-hosted git repository.
lucasbru pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/trunk by this push:
new 34e426b5178 KAFKA-19271: Add internal ConsumerWrapper (#19697)
34e426b5178 is described below
commit 34e426b5178a5376668df6a2c6cf130fce78d431
Author: Matthias J. Sax <[email protected]>
AuthorDate: Fri May 16 02:57:37 2025 -0700
KAFKA-19271: Add internal ConsumerWrapper (#19697)
With KIP-1071 enabled, the main consumer is created differently,
sidestepping `KafkaClientSupplier`.
To allow injecting test wrappers, we add an internal ConsumerWrapper,
until we define a new public interface.
Reviewers: Lucas Brutschy <[email protected]>
---
.../KafkaStreamsTelemetryIntegrationTest.java | 148 +++++++---
.../org/apache/kafka/streams/StreamsConfig.java | 3 +
.../kafka/streams/internals/ConsumerWrapper.java | 316 +++++++++++++++++++++
.../streams/processor/internals/StreamThread.java | 45 ++-
.../processor/internals/StreamThreadTest.java | 33 +++
5 files changed, 508 insertions(+), 37 deletions(-)
diff --git
a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/KafkaStreamsTelemetryIntegrationTest.java
b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/KafkaStreamsTelemetryIntegrationTest.java
index 0af02321a67..eefb4e3e287 100644
---
a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/KafkaStreamsTelemetryIntegrationTest.java
+++
b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/KafkaStreamsTelemetryIntegrationTest.java
@@ -22,6 +22,8 @@ import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.KafkaConsumer;
+import org.apache.kafka.clients.consumer.internals.AsyncKafkaConsumer;
+import org.apache.kafka.clients.consumer.internals.StreamsRebalanceData;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.common.Metric;
@@ -46,6 +48,7 @@ import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.internals.ConsumerWrapper;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Produced;
import org.apache.kafka.streams.processor.api.Processor;
@@ -62,7 +65,6 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
-import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
@@ -110,13 +112,23 @@ public class KafkaStreamsTelemetryIntegrationTest {
private KeyValueIterator<String, String> globalStoreIterator;
private static EmbeddedKafkaCluster cluster;
- private static final List<TestingMetricsInterceptingConsumer<byte[],
byte[]>> INTERCEPTING_CONSUMERS = new ArrayList<>();
+ private static final List<TestingMetricsInterceptor>
INTERCEPTING_CONSUMERS = new ArrayList<>();
private static final List<TestingMetricsInterceptingAdminClient>
INTERCEPTING_ADMIN_CLIENTS = new ArrayList<>();
private static final int NUM_BROKERS = 3;
private static final int FIRST_INSTANCE_CLIENT = 0;
private static final int SECOND_INSTANCE_CLIENT = 1;
private static final Logger LOG =
LoggerFactory.getLogger(KafkaStreamsTelemetryIntegrationTest.class);
+ static Stream<Arguments> recordingLevelParameters() {
+ return Stream.of(
+ Arguments.of("INFO", "classic"),
+ Arguments.of("DEBUG", "classic"),
+ Arguments.of("TRACE", "classic"),
+ Arguments.of("INFO", "streams"),
+ Arguments.of("DEBUG", "streams"),
+ Arguments.of("TRACE", "streams")
+ );
+ }
@BeforeAll
public static void startCluster() throws IOException {
@@ -160,9 +172,9 @@ public class KafkaStreamsTelemetryIntegrationTest {
}
@ParameterizedTest
- @ValueSource(strings = {"INFO", "DEBUG", "TRACE"})
- public void shouldPushGlobalThreadMetricsToBroker(final String
recordingLevel) throws Exception {
- streamsApplicationProperties = props(true);
+ @MethodSource("recordingLevelParameters")
+ public void shouldPushGlobalThreadMetricsToBroker(final String
recordingLevel, final String groupProtocol) throws Exception {
+ streamsApplicationProperties = props(true, groupProtocol);
streamsApplicationProperties.put(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG,
recordingLevel);
final Topology topology = simpleTopology(true);
subscribeForStreamsMetrics();
@@ -198,10 +210,10 @@ public class KafkaStreamsTelemetryIntegrationTest {
}
@ParameterizedTest
- @ValueSource(strings = {"INFO", "DEBUG", "TRACE"})
- public void shouldPushMetricsToBroker(final String recordingLevel) throws
Exception {
+ @MethodSource("recordingLevelParameters")
+ public void shouldPushMetricsToBroker(final String recordingLevel, final
String groupProtocol) throws Exception {
// End-to-end test validating metrics pushed to broker
- streamsApplicationProperties = props(true);
+ streamsApplicationProperties = props(true, groupProtocol);
streamsApplicationProperties.put(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG,
recordingLevel);
final Topology topology = simpleTopology(false);
subscribeForStreamsMetrics();
@@ -263,9 +275,9 @@ public class KafkaStreamsTelemetryIntegrationTest {
@ParameterizedTest
@MethodSource("singleAndMultiTaskParameters")
- public void shouldPassMetrics(final String topologyType, final boolean
stateUpdaterEnabled) throws Exception {
+ public void shouldPassMetrics(final String topologyType, final boolean
stateUpdaterEnabled, final String groupProtocol) throws Exception {
// Streams metrics should get passed to Admin and Consumer
- streamsApplicationProperties = props(stateUpdaterEnabled);
+ streamsApplicationProperties = props(stateUpdaterEnabled,
groupProtocol);
final Topology topology = topologyType.equals("simple") ?
simpleTopology(false) : complexTopology();
try (final KafkaStreams streams = new KafkaStreams(topology,
streamsApplicationProperties)) {
@@ -279,7 +291,7 @@ public class KafkaStreamsTelemetryIntegrationTest {
- final List<MetricName> consumerPassedStreamThreadMetricNames =
INTERCEPTING_CONSUMERS.get(FIRST_INSTANCE_CLIENT).passedMetrics.stream().map(KafkaMetric::metricName).toList();
+ final List<MetricName> consumerPassedStreamThreadMetricNames =
INTERCEPTING_CONSUMERS.get(FIRST_INSTANCE_CLIENT).passedMetrics().stream().map(KafkaMetric::metricName).toList();
final List<MetricName> adminPassedStreamClientMetricNames =
INTERCEPTING_ADMIN_CLIENTS.get(FIRST_INSTANCE_CLIENT).passedMetrics.stream().map(KafkaMetric::metricName).toList();
@@ -293,14 +305,14 @@ public class KafkaStreamsTelemetryIntegrationTest {
@ParameterizedTest
@MethodSource("multiTaskParameters")
- public void shouldPassCorrectMetricsDynamicInstances(final boolean
stateUpdaterEnabled) throws Exception {
+ public void shouldPassCorrectMetricsDynamicInstances(final boolean
stateUpdaterEnabled, final String groupProtocol) throws Exception {
// Correct streams metrics should get passed with dynamic membership
- streamsApplicationProperties = props(stateUpdaterEnabled);
+ streamsApplicationProperties = props(stateUpdaterEnabled,
groupProtocol);
streamsApplicationProperties.put(StreamsConfig.STATE_DIR_CONFIG,
TestUtils.tempDirectory(appId).getPath() + "-ks1");
streamsApplicationProperties.put(StreamsConfig.CLIENT_ID_CONFIG, appId
+ "-ks1");
- streamsSecondApplicationProperties = props(stateUpdaterEnabled);
+ streamsSecondApplicationProperties = props(stateUpdaterEnabled,
groupProtocol);
streamsSecondApplicationProperties.put(StreamsConfig.STATE_DIR_CONFIG,
TestUtils.tempDirectory(appId).getPath() + "-ks2");
streamsSecondApplicationProperties.put(StreamsConfig.CLIENT_ID_CONFIG,
appId + "-ks2");
@@ -312,7 +324,7 @@ public class KafkaStreamsTelemetryIntegrationTest {
final List<MetricName> streamsTaskMetricNames =
streamsOne.metrics().values().stream().map(Metric::metricName)
.filter(metricName ->
metricName.tags().containsKey("task-id")).toList();
- final List<MetricName> consumerPassedStreamTaskMetricNames =
INTERCEPTING_CONSUMERS.get(FIRST_INSTANCE_CLIENT).passedMetrics.stream().map(KafkaMetric::metricName)
+ final List<MetricName> consumerPassedStreamTaskMetricNames =
INTERCEPTING_CONSUMERS.get(FIRST_INSTANCE_CLIENT).passedMetrics().stream().map(KafkaMetric::metricName)
.filter(metricName ->
metricName.tags().containsKey("task-id")).toList();
/*
@@ -349,9 +361,9 @@ public class KafkaStreamsTelemetryIntegrationTest {
.filter(metricName ->
metricName.group().equals("stream-state-metrics")).toList();
final List<MetricName> consumerOnePassedTaskMetrics =
INTERCEPTING_CONSUMERS.get(FIRST_INSTANCE_CLIENT)
-
.passedMetrics.stream().map(KafkaMetric::metricName).filter(metricName ->
metricName.tags().containsKey("task-id")).toList();
+
.passedMetrics().stream().map(KafkaMetric::metricName).filter(metricName ->
metricName.tags().containsKey("task-id")).toList();
final List<MetricName> consumerOnePassedStateMetrics =
INTERCEPTING_CONSUMERS.get(FIRST_INSTANCE_CLIENT)
-
.passedMetrics.stream().map(KafkaMetric::metricName).filter(metricName ->
metricName.group().equals("stream-state-metrics")).toList();
+
.passedMetrics().stream().map(KafkaMetric::metricName).filter(metricName ->
metricName.group().equals("stream-state-metrics")).toList();
final List<MetricName> streamsTwoTaskMetrics =
streamsTwo.metrics().values().stream().map(Metric::metricName)
.filter(metricName ->
metricName.tags().containsKey("task-id")).toList();
@@ -359,9 +371,9 @@ public class KafkaStreamsTelemetryIntegrationTest {
.filter(metricName ->
metricName.group().equals("stream-state-metrics")).toList();
final List<MetricName> consumerTwoPassedTaskMetrics =
INTERCEPTING_CONSUMERS.get(SECOND_INSTANCE_CLIENT)
-
.passedMetrics.stream().map(KafkaMetric::metricName).filter(metricName ->
metricName.tags().containsKey("task-id")).toList();
+
.passedMetrics().stream().map(KafkaMetric::metricName).filter(metricName ->
metricName.tags().containsKey("task-id")).toList();
final List<MetricName> consumerTwoPassedStateMetrics =
INTERCEPTING_CONSUMERS.get(SECOND_INSTANCE_CLIENT)
-
.passedMetrics.stream().map(KafkaMetric::metricName).filter(metricName ->
metricName.group().equals("stream-state-metrics")).toList();
+
.passedMetrics().stream().map(KafkaMetric::metricName).filter(metricName ->
metricName.group().equals("stream-state-metrics")).toList();
/*
Confirm pre-existing KafkaStreams instance one only passes
metrics for its tasks and has no metrics for previous tasks
*/
@@ -391,10 +403,11 @@ public class KafkaStreamsTelemetryIntegrationTest {
}
}
- @Test
- public void passedMetricsShouldNotLeakIntoClientMetrics() throws Exception
{
+ @ParameterizedTest
+ @ValueSource(strings = {"classic", "streams"})
+ public void passedMetricsShouldNotLeakIntoClientMetrics(final String
groupProtocol) throws Exception {
// Streams metrics should not be visible in client metrics
- streamsApplicationProperties = props(true);
+ streamsApplicationProperties = props(true, groupProtocol);
final Topology topology = complexTopology();
try (final KafkaStreams streams = new KafkaStreams(topology,
streamsApplicationProperties)) {
@@ -423,6 +436,7 @@ public class KafkaStreamsTelemetryIntegrationTest {
clientMetricsService.alterClientMetrics(commandOptions);
}
}
+
private List<String> getTaskIdsAsStrings(final KafkaStreams streams) {
return streams.metadataForLocalThreads().stream()
.flatMap(threadMeta -> threadMeta.activeTasks().stream()
@@ -431,19 +445,32 @@ public class KafkaStreamsTelemetryIntegrationTest {
}
private static Stream<Arguments> singleAndMultiTaskParameters() {
- return Stream.of(Arguments.of("simple", true),
- Arguments.of("simple", false),
- Arguments.of("complex", true),
- Arguments.of("complex", false));
+ return Stream.of(
+ Arguments.of("simple", true, "classic"),
+ Arguments.of("simple", false, "classic"),
+ Arguments.of("complex", true, "classic"),
+ Arguments.of("complex", false, "classic"),
+ Arguments.of("simple", true, "streams"),
+ Arguments.of("simple", false, "streams"),
+ Arguments.of("complex", true, "streams"),
+ Arguments.of("complex", false, "streams")
+ );
}
private static Stream<Arguments> multiTaskParameters() {
- return Stream.of(Arguments.of(true),
- Arguments.of(false));
+ return Stream.of(
+ Arguments.of(true, "classic"),
+ Arguments.of(false, "classic"),
+ Arguments.of(true, "streams"),
+ Arguments.of(false, "streams")
+ );
}
- private Properties props(final boolean stateUpdaterEnabled) {
- return
props(mkObjectProperties(mkMap(mkEntry(StreamsConfig.InternalConfig.STATE_UPDATER_ENABLED,
stateUpdaterEnabled))));
+ private Properties props(final boolean stateUpdaterEnabled, final String
groupProtocol) {
+ return props(mkObjectProperties(mkMap(
+ mkEntry(StreamsConfig.InternalConfig.STATE_UPDATER_ENABLED,
stateUpdaterEnabled),
+ mkEntry(StreamsConfig.GROUP_PROTOCOL_CONFIG, groupProtocol)
+ )));
}
private Properties props(final Properties extraProperties) {
@@ -455,6 +482,7 @@ public class KafkaStreamsTelemetryIntegrationTest {
streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG,
Serdes.StringSerde.class);
streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG,
Serdes.StringSerde.class);
streamsConfiguration.put(StreamsConfig.DEFAULT_CLIENT_SUPPLIER_CONFIG,
TestClientSupplier.class);
+
streamsConfiguration.put(StreamsConfig.InternalConfig.INTERNAL_CONSUMER_WRAPPER,
TestConsumerWrapper.class);
streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG,
"earliest");
streamsConfiguration.putAll(extraProperties);
return streamsConfiguration;
@@ -545,9 +573,23 @@ public class KafkaStreamsTelemetryIntegrationTest {
}
}
- public static class TestingMetricsInterceptingConsumer<K, V> extends
KafkaConsumer<K, V> {
+ public static class TestConsumerWrapper extends ConsumerWrapper {
+ @Override
+ public void wrapConsumer(final AsyncKafkaConsumer<byte[], byte[]>
delegate, final Map<String, Object> config, final
Optional<StreamsRebalanceData> streamsRebalanceData) {
+ final TestingMetricsInterceptingAsynConsumer<byte[], byte[]>
consumer = new TestingMetricsInterceptingAsynConsumer<>(config, new
ByteArrayDeserializer(), new ByteArrayDeserializer(), streamsRebalanceData);
+ INTERCEPTING_CONSUMERS.add(consumer);
+
+ super.wrapConsumer(consumer, config, streamsRebalanceData);
+ }
+ }
+
+ public interface TestingMetricsInterceptor {
+ List<KafkaMetric> passedMetrics();
+ Map<MetricName, ? extends Metric> metrics();
+ }
- public List<KafkaMetric> passedMetrics = new ArrayList<>();
+ public static class TestingMetricsInterceptingConsumer<K, V> extends
KafkaConsumer<K, V> implements TestingMetricsInterceptor {
+ private final List<KafkaMetric> passedMetrics = new ArrayList<>();
public TestingMetricsInterceptingConsumer(final Map<String, Object>
configs, final Deserializer<K> keyDeserializer, final Deserializer<V>
valueDeserializer) {
super(configs, keyDeserializer, valueDeserializer);
@@ -564,6 +606,48 @@ public class KafkaStreamsTelemetryIntegrationTest {
passedMetrics.remove(metric);
super.unregisterMetricFromSubscription(metric);
}
+
+ @Override
+ public List<KafkaMetric> passedMetrics() {
+ return passedMetrics;
+ }
+ }
+
+ public static class TestingMetricsInterceptingAsynConsumer<K, V> extends
AsyncKafkaConsumer<K, V> implements TestingMetricsInterceptor {
+ private final List<KafkaMetric> passedMetrics = new ArrayList<>();
+
+ public TestingMetricsInterceptingAsynConsumer(
+ final Map<String, Object> configs,
+ final Deserializer<K> keyDeserializer,
+ final Deserializer<V> valueDeserializer,
+ final Optional<StreamsRebalanceData> streamsRebalanceData
+ ) {
+ super(
+ new ConsumerConfig(
+ ConsumerConfig.appendDeserializerToConfig(configs,
keyDeserializer, valueDeserializer)
+ ),
+ keyDeserializer,
+ valueDeserializer,
+ streamsRebalanceData
+ );
+ }
+
+ @Override
+ public void registerMetricForSubscription(final KafkaMetric metric) {
+ passedMetrics.add(metric);
+ super.registerMetricForSubscription(metric);
+ }
+
+ @Override
+ public void unregisterMetricFromSubscription(final KafkaMetric metric)
{
+ passedMetrics.remove(metric);
+ super.unregisterMetricFromSubscription(metric);
+ }
+
+ @Override
+ public List<KafkaMetric> passedMetrics() {
+ return passedMetrics;
+ }
}
public static class TelemetryPlugin implements ClientTelemetry,
MetricsReporter, ClientTelemetryReceiver {
diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
index e19cb03b207..c8fa251b6be 100644
--- a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
+++ b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
@@ -1320,6 +1320,9 @@ public class StreamsConfig extends AbstractConfig {
// This is settable in the main Streams config, but it's a private API
for testing
public static final String ASSIGNMENT_LISTENER =
"__assignment.listener__";
+ // This is settable in the main Streams config, but it's a private API
for testing
+ public static final String INTERNAL_CONSUMER_WRAPPER =
"__internal.consumer.wrapper__";
+
// Private API used to control the emit latency for left/outer join
results (https://issues.apache.org/jira/browse/KAFKA-10847)
public static final String
EMIT_INTERVAL_MS_KSTREAMS_OUTER_JOIN_SPURIOUS_RESULTS_FIX =
"__emit.interval.ms.kstreams.outer.join.spurious.results.fix__";
diff --git
a/streams/src/main/java/org/apache/kafka/streams/internals/ConsumerWrapper.java
b/streams/src/main/java/org/apache/kafka/streams/internals/ConsumerWrapper.java
new file mode 100644
index 00000000000..20cd7cb84d7
--- /dev/null
+++
b/streams/src/main/java/org/apache/kafka/streams/internals/ConsumerWrapper.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.internals;
+
+import org.apache.kafka.clients.consumer.CloseOptions;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerGroupMetadata;
+import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
+import org.apache.kafka.clients.consumer.OffsetCommitCallback;
+import org.apache.kafka.clients.consumer.SubscriptionPattern;
+import org.apache.kafka.clients.consumer.internals.AsyncKafkaConsumer;
+import org.apache.kafka.clients.consumer.internals.StreamsRebalanceData;
+import org.apache.kafka.common.Metric;
+import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.metrics.KafkaMetric;
+
+import java.time.Duration;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.OptionalLong;
+import java.util.Set;
+import java.util.regex.Pattern;
+
+public abstract class ConsumerWrapper implements Consumer<byte[], byte[]> {
+ protected AsyncKafkaConsumer<byte[], byte[]> delegate;
+
+ public void wrapConsumer(
+ final AsyncKafkaConsumer<byte[], byte[]> delegate,
+ final Map<String, Object> config,
+ final Optional<StreamsRebalanceData> streamsRebalanceData
+ ) {
+ this.delegate = delegate;
+ }
+
+ public AsyncKafkaConsumer<byte[], byte[]> consumer() {
+ return delegate;
+ }
+
+ @Override
+ public Set<TopicPartition> assignment() {
+ return delegate.assignment();
+ }
+
+ @Override
+ public Set<String> subscription() {
+ return delegate.subscription();
+ }
+
+ @Override
+ public void subscribe(final Collection<String> topics) {
+ delegate.subscribe(topics);
+ }
+
+ @Override
+ public void subscribe(final Collection<String> topics, final
ConsumerRebalanceListener callback) {
+ delegate.subscribe(topics, callback);
+ }
+
+ @Override
+ public void assign(final Collection<TopicPartition> partitions) {
+ delegate.assign(partitions);
+ }
+
+ @Override
+ public void subscribe(final Pattern pattern, final
ConsumerRebalanceListener callback) {
+ delegate.subscribe(pattern, callback);
+ }
+
+ @Override
+ public void subscribe(final Pattern pattern) {
+ delegate.subscribe(pattern);
+ }
+
+ @Override
+ public void subscribe(final SubscriptionPattern pattern, final
ConsumerRebalanceListener callback) {
+ delegate.subscribe(pattern, callback);
+ }
+
+ @Override
+ public void subscribe(final SubscriptionPattern pattern) {
+ delegate.subscribe(pattern);
+ }
+
+ @Override
+ public void unsubscribe() {
+ delegate.unsubscribe();
+ }
+
+ @Override
+ public ConsumerRecords<byte[], byte[]> poll(final Duration timeout) {
+ return delegate.poll(timeout);
+ }
+
+ @Override
+ public void commitSync() {
+ delegate.commitSync();
+ }
+
+ @Override
+ public void commitSync(final Duration timeout) {
+ delegate.commitSync(timeout);
+ }
+
+ @Override
+ public void commitSync(final Map<TopicPartition, OffsetAndMetadata>
offsets) {
+ delegate.commitSync(offsets);
+ }
+
+ @Override
+ public void commitSync(final Map<TopicPartition, OffsetAndMetadata>
offsets, final Duration timeout) {
+ delegate.commitSync(offsets, timeout);
+ }
+
+ @Override
+ public void commitAsync() {
+ delegate.commitAsync();
+ }
+
+ @Override
+ public void commitAsync(final OffsetCommitCallback callback) {
+ delegate.commitAsync(callback);
+ }
+
+ @Override
+ public void commitAsync(final Map<TopicPartition, OffsetAndMetadata>
offsets, final OffsetCommitCallback callback) {
+ delegate.commitAsync(offsets, callback);
+ }
+
+ @Override
+ public void registerMetricForSubscription(final KafkaMetric metric) {
+ delegate.registerMetricForSubscription(metric);
+ }
+
+ @Override
+ public void unregisterMetricFromSubscription(final KafkaMetric metric) {
+ delegate.unregisterMetricFromSubscription(metric);
+ }
+
+ @Override
+ public void seek(final TopicPartition partition, final long offset) {
+ delegate.seek(partition, offset);
+ }
+
+ @Override
+ public void seek(final TopicPartition partition, final OffsetAndMetadata
offsetAndMetadata) {
+ delegate.seek(partition, offsetAndMetadata);
+ }
+
+ @Override
+ public void seekToBeginning(final Collection<TopicPartition> partitions) {
+ delegate.seekToBeginning(partitions);
+ }
+
+ @Override
+ public void seekToEnd(final Collection<TopicPartition> partitions) {
+ delegate.seekToEnd(partitions);
+ }
+
+ @Override
+ public long position(final TopicPartition partition) {
+ return delegate.position(partition);
+ }
+
+ @Override
+ public long position(final TopicPartition partition, final Duration
timeout) {
+ return delegate.position(partition, timeout);
+ }
+
+ @Override
+ public Map<TopicPartition, OffsetAndMetadata> committed(final
Set<TopicPartition> partitions) {
+ return delegate.committed(partitions);
+ }
+
+ @Override
+ public Map<TopicPartition, OffsetAndMetadata> committed(final
Set<TopicPartition> partitions, final Duration timeout) {
+ return delegate.committed(partitions, timeout);
+ }
+
+ @Override
+ public Uuid clientInstanceId(final Duration timeout) {
+ return delegate.clientInstanceId(timeout);
+ }
+
+ @Override
+ public Map<MetricName, ? extends Metric> metrics() {
+ return delegate.metrics();
+ }
+
+ @Override
+ public List<PartitionInfo> partitionsFor(final String topic) {
+ return delegate.partitionsFor(topic);
+ }
+
+ @Override
+ public List<PartitionInfo> partitionsFor(final String topic, final
Duration timeout) {
+ return delegate.partitionsFor(topic, timeout);
+ }
+
+ @Override
+ public Map<String, List<PartitionInfo>> listTopics() {
+ return delegate.listTopics();
+ }
+
+ @Override
+ public Map<String, List<PartitionInfo>> listTopics(final Duration timeout)
{
+ return delegate.listTopics(timeout);
+ }
+
+ @Override
+ public Set<TopicPartition> paused() {
+ return delegate.paused();
+ }
+
+ @Override
+ public void pause(final Collection<TopicPartition> partitions) {
+ delegate.pause(partitions);
+ }
+
+ @Override
+ public void resume(final Collection<TopicPartition> partitions) {
+ delegate.resume(partitions);
+ }
+
+ @Override
+ public Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(final
Map<TopicPartition, Long> timestampsToSearch) {
+ return delegate.offsetsForTimes(timestampsToSearch);
+ }
+
+ @Override
+ public Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(final
Map<TopicPartition, Long> timestampsToSearch, final Duration timeout) {
+ return delegate.offsetsForTimes(timestampsToSearch, timeout);
+ }
+
+ @Override
+ public Map<TopicPartition, Long> beginningOffsets(final
Collection<TopicPartition> partitions) {
+ return delegate.beginningOffsets(partitions);
+ }
+
+ @Override
+ public Map<TopicPartition, Long> beginningOffsets(final
Collection<TopicPartition> partitions, final Duration timeout) {
+ return delegate.beginningOffsets(partitions, timeout);
+ }
+
+ @Override
+ public Map<TopicPartition, Long> endOffsets(final
Collection<TopicPartition> partitions) {
+ return delegate.endOffsets(partitions);
+ }
+
+ @Override
+ public Map<TopicPartition, Long> endOffsets(final
Collection<TopicPartition> partitions, final Duration timeout) {
+ return delegate.endOffsets(partitions, timeout);
+ }
+
+ @Override
+ public OptionalLong currentLag(final TopicPartition topicPartition) {
+ return delegate.currentLag(topicPartition);
+ }
+
+ @Override
+ public ConsumerGroupMetadata groupMetadata() {
+ return delegate.groupMetadata();
+ }
+
+ @Override
+ public void enforceRebalance() {
+ delegate.enforceRebalance();
+ }
+
+ @Override
+ public void enforceRebalance(final String reason) {
+ delegate.enforceRebalance(reason);
+ }
+
+ @Override
+ public void close() {
+ delegate.close();
+ }
+
+ @Deprecated
+ @Override
+ public void close(final Duration timeout) {
+ delegate.close(timeout);
+ }
+
+ @Override
+ public void close(final CloseOptions option) {
+ delegate.close(option);
+ }
+
+ @Override
+ public void wakeup() {
+ delegate.wakeup();
+ }
+}
\ No newline at end of file
diff --git
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 1944f25f9df..fdc5e8df4bc 100644
---
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -55,6 +55,7 @@ import org.apache.kafka.streams.ThreadMetadata;
import org.apache.kafka.streams.errors.StreamsException;
import org.apache.kafka.streams.errors.TaskCorruptedException;
import org.apache.kafka.streams.errors.TaskMigratedException;
+import org.apache.kafka.streams.internals.ConsumerWrapper;
import org.apache.kafka.streams.internals.metrics.ClientMetrics;
import
org.apache.kafka.streams.internals.metrics.StreamsThreadMetricsDelegatingReporter;
import org.apache.kafka.streams.processor.StandbyUpdateListener;
@@ -550,11 +551,16 @@ public class StreamThread extends Thread implements
ProcessingThread {
);
final ByteArrayDeserializer keyDeserializer = new
ByteArrayDeserializer();
final ByteArrayDeserializer valueDeserializer = new
ByteArrayDeserializer();
+
return new MainConsumerSetup(
- new AsyncKafkaConsumer<>(
- new
ConsumerConfig(ConsumerConfig.appendDeserializerToConfig(consumerConfigs,
keyDeserializer, valueDeserializer)),
- keyDeserializer,
- valueDeserializer,
+ maybeWrapConsumer(
+ consumerConfigs,
+ new AsyncKafkaConsumer<>(
+ new
ConsumerConfig(ConsumerConfig.appendDeserializerToConfig(consumerConfigs,
keyDeserializer, valueDeserializer)),
+ keyDeserializer,
+ valueDeserializer,
+ streamsRebalanceData
+ ),
streamsRebalanceData
),
streamsRebalanceData
@@ -567,6 +573,32 @@ public class StreamThread extends Thread implements
ProcessingThread {
}
}
+ private static Consumer<byte[], byte[]> maybeWrapConsumer(final
Map<String, Object> config,
+ final
AsyncKafkaConsumer<byte[], byte[]> consumer,
+ final
Optional<StreamsRebalanceData> streamsRebalanceData) {
+ final Object o = config.get(InternalConfig.INTERNAL_CONSUMER_WRAPPER);
+ if (o == null) {
+ return consumer;
+ }
+
+ final ConsumerWrapper wrapper;
+ if (o instanceof String) {
+ try {
+ wrapper = Utils.newInstance((String) o, ConsumerWrapper.class);
+ } catch (final ClassNotFoundException e) {
+ throw new IllegalArgumentException(e);
+ }
+ } else if (o instanceof Class<?>) {
+ wrapper = (ConsumerWrapper) Utils.newInstance((Class<?>) o);
+ } else {
+ throw new IllegalArgumentException("Internal config " +
InternalConfig.INTERNAL_CONSUMER_WRAPPER + " must be a class or class name");
+ }
+
+ wrapper.wrapConsumer(consumer, config, streamsRebalanceData);
+
+ return wrapper;
+ }
+
private static class MainConsumerSetup {
public final Consumer<byte[], byte[]> mainConsumer;
public final Optional<StreamsRebalanceData> streamsRebalanceData;
@@ -1105,7 +1137,10 @@ public class StreamThread extends Thread implements
ProcessingThread {
mainConsumer.subscribe(topologyMetadata.sourceTopicPattern(),
rebalanceListener);
} else {
if (streamsRebalanceData.isPresent()) {
- ((AsyncKafkaConsumer<byte[], byte[]>) mainConsumer).subscribe(
+ final AsyncKafkaConsumer<byte[], byte[]> consumer =
mainConsumer instanceof ConsumerWrapper
+ ? ((ConsumerWrapper) mainConsumer).consumer()
+ : (AsyncKafkaConsumer<byte[], byte[]>) mainConsumer;
+ consumer.subscribe(
topologyMetadata.allFullSourceTopicNames(),
new DefaultStreamsRebalanceListener(
log,
diff --git
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 96090aa32fa..54230d11d3b 100644
---
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -69,6 +69,7 @@ import
org.apache.kafka.streams.errors.LogAndContinueExceptionHandler;
import org.apache.kafka.streams.errors.StreamsException;
import org.apache.kafka.streams.errors.TaskCorruptedException;
import org.apache.kafka.streams.errors.TaskMigratedException;
+import org.apache.kafka.streams.internals.ConsumerWrapper;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.kstream.internals.ConsumedInternal;
@@ -3925,6 +3926,38 @@ public class StreamThreadTest {
);
}
+ @ParameterizedTest
+ @MethodSource("data")
+ public void shouldWrapMainConsumerFromClassConfig(final boolean
stateUpdaterEnabled, final boolean processingThreadsEnabled) {
+ final Properties streamsConfigProps = configProps(false,
stateUpdaterEnabled, processingThreadsEnabled);
+ streamsConfigProps.put(StreamsConfig.GROUP_PROTOCOL_CONFIG, "streams");
+ streamsConfigProps.put(InternalConfig.INTERNAL_CONSUMER_WRAPPER,
TestWrapper.class);
+
+ thread = createStreamThread("clientId", new
StreamsConfig(streamsConfigProps));
+
+ assertInstanceOf(
+ AsyncKafkaConsumer.class,
+ assertInstanceOf(TestWrapper.class,
thread.mainConsumer()).consumer()
+ );
+ }
+
+ @ParameterizedTest
+ @MethodSource("data")
+ public void shouldWrapMainConsumerFromStringConfig(final boolean
stateUpdaterEnabled, final boolean processingThreadsEnabled) {
+ final Properties streamsConfigProps = configProps(false,
stateUpdaterEnabled, processingThreadsEnabled);
+ streamsConfigProps.put(StreamsConfig.GROUP_PROTOCOL_CONFIG, "streams");
+ streamsConfigProps.put(InternalConfig.INTERNAL_CONSUMER_WRAPPER,
TestWrapper.class.getName());
+
+ thread = createStreamThread("clientId", new
StreamsConfig(streamsConfigProps));
+
+ assertInstanceOf(
+ AsyncKafkaConsumer.class,
+ assertInstanceOf(TestWrapper.class,
thread.mainConsumer()).consumer()
+ );
+ }
+
+ public static final class TestWrapper extends ConsumerWrapper { }
+
private StreamThread setUpThread(final Properties streamsConfigProps) {
final StreamsConfig config = new StreamsConfig(streamsConfigProps);
final ConsumerGroupMetadata consumerGroupMetadata =
Mockito.mock(ConsumerGroupMetadata.class);