This is an automated email from the ASF dual-hosted git repository.

cadonna pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 2a370ed7213 KAFKA-19037: Integrate consumer-side code with Streams 
(#19377)
2a370ed7213 is described below

commit 2a370ed7213c7de74acef8b66955c4b91ba791e2
Author: Bruno Cadonna <[email protected]>
AuthorDate: Wed Apr 9 13:26:51 2025 +0200

    KAFKA-19037: Integrate consumer-side code with Streams (#19377)
    
    The consumer adaptations for the new Streams rebalance protocol need to
    be integrated into the Streams code. This commit does the following:
    - creates an async Kafka consumer
      - with a Streams heartbeat request manager
      - with a Streams membership manager
    - integrates consumer code with the Streams membership manager and the
    Streams heartbeat request manager
    - processes the events from the consumer network thread (a.k.a.
    background thread)
    that request the invocation of the "on tasks revoked", "on tasks
    assigned", and "on all tasks lost"
      callbacks
    - executes the callbacks
    - sends to the consumer network thread the events signalling the
    execution of the callbacks
    - adapts SmokeTestDriverIntegrationTest to use the new Streams rebalance
    protocol
    
    This commit misses some unit test coverage, but it also unblocks other
    work on trunk regarding the new Streams rebalance protocol.  The missing
    unit tests will be added soon.
    
    Reviewers: Lucas Brutschy <[email protected]>
---
 .../consumer/internals/AsyncKafkaConsumer.java     |  30 ++-
 .../internals/ConsumerDelegateCreator.java         |   3 +-
 .../consumer/internals/RequestManagers.java        |  93 ++++++---
 .../StreamsGroupHeartbeatRequestManager.java       |  21 +-
 .../consumer/internals/StreamsRebalanceData.java   |   5 +-
 .../internals/StreamsRebalanceEventsProcessor.java |  13 +-
 .../events/ApplicationEventProcessor.java          |  97 +++++++--
 .../consumer/internals/AsyncKafkaConsumerTest.java |  12 +-
 .../consumer/internals/RequestManagersTest.java    |   4 +-
 .../StreamsRebalanceEventsProcessorTest.java       |  18 +-
 .../events/ApplicationEventProcessorTest.java      |   5 +-
 .../SmokeTestDriverIntegrationTest.java            |  44 +++-
 .../org/apache/kafka/streams/GroupProtocol.java    |  43 ++++
 .../org/apache/kafka/streams/StreamsConfig.java    |  24 +++
 .../DefaultStreamsGroupRebalanceCallbacks.java     | 130 ++++++++++++
 .../streams/processor/internals/StreamThread.java  | 229 ++++++++++++++++++++-
 .../apache/kafka/streams/StreamsConfigTest.java    |  20 ++
 .../processor/internals/StreamThreadTest.java      |  30 +++
 .../kafka/streams/tests/SmokeTestClient.java       |   4 +-
 19 files changed, 733 insertions(+), 92 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
index ba2e145d327..e421ef93191 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
@@ -280,10 +280,12 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
             setGroupAssignmentSnapshot(partitions);
         }
     };
-
-    AsyncKafkaConsumer(final ConsumerConfig config,
-                       final Deserializer<K> keyDeserializer,
-                       final Deserializer<V> valueDeserializer) {
+    
+    public AsyncKafkaConsumer(final ConsumerConfig config,
+                              final Deserializer<K> keyDeserializer,
+                              final Deserializer<V> valueDeserializer,
+                              final Optional<StreamsRebalanceData> 
streamsRebalanceData,
+                              final Optional<StreamsRebalanceEventsProcessor> 
streamsRebalanceEventsProcessor) {
         this(
             config,
             keyDeserializer,
@@ -293,11 +295,14 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
             CompletableEventReaper::new,
             FetchCollector::new,
             ConsumerMetadata::new,
-            new LinkedBlockingQueue<>()
+            new LinkedBlockingQueue<>(),
+            streamsRebalanceData,
+            streamsRebalanceEventsProcessor
         );
     }
 
     // Visible for testing
+    @SuppressWarnings({"this-escape"})
     AsyncKafkaConsumer(final ConsumerConfig config,
                        final Deserializer<K> keyDeserializer,
                        final Deserializer<V> valueDeserializer,
@@ -306,7 +311,9 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
                        final CompletableEventReaperFactory 
backgroundEventReaperFactory,
                        final FetchCollectorFactory<K, V> fetchCollectorFactory,
                        final ConsumerMetadataFactory metadataFactory,
-                       final LinkedBlockingQueue<BackgroundEvent> 
backgroundEventQueue) {
+                       final LinkedBlockingQueue<BackgroundEvent> 
backgroundEventQueue,
+                       final Optional<StreamsRebalanceData> 
streamsRebalanceData,
+                       final Optional<StreamsRebalanceEventsProcessor> 
streamsRebalanceEventsProcessor) {
         try {
             GroupRebalanceConfig groupRebalanceConfig = new 
GroupRebalanceConfig(
                 config,
@@ -382,7 +389,9 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
                     clientTelemetryReporter,
                     metrics,
                     offsetCommitCallbackInvoker,
-                    memberStateListener
+                    memberStateListener,
+                    streamsRebalanceData,
+                    streamsRebalanceEventsProcessor
             );
             final Supplier<ApplicationEventProcessor> 
applicationEventProcessorSupplier = 
ApplicationEventProcessor.supplier(logContext,
                     metadata,
@@ -398,6 +407,9 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
                     requestManagersSupplier,
                     kafkaConsumerMetrics
             );
+            streamsRebalanceEventsProcessor.ifPresent(
+                processor -> 
processor.setApplicationEventHandler(applicationEventHandler)
+            );
 
             this.rebalanceListenerInvoker = new 
ConsumerRebalanceListenerInvoker(
                     logContext,
@@ -568,7 +580,9 @@ public class AsyncKafkaConsumer<K, V> implements 
ConsumerDelegate<K, V> {
             clientTelemetryReporter,
             metrics,
             offsetCommitCallbackInvoker,
-            memberStateListener
+            memberStateListener,
+            Optional.empty(),
+            Optional.empty()
         );
         Supplier<ApplicationEventProcessor> applicationEventProcessorSupplier 
= ApplicationEventProcessor.supplier(
                 logContext,
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerDelegateCreator.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerDelegateCreator.java
index 74592972b9d..60d12aec864 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerDelegateCreator.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerDelegateCreator.java
@@ -29,6 +29,7 @@ import org.apache.kafka.common.utils.Time;
 
 import java.util.List;
 import java.util.Locale;
+import java.util.Optional;
 
 /**
  * {@code ConsumerDelegateCreator} implements a quasi-factory pattern to allow 
the caller to remain unaware of the
@@ -60,7 +61,7 @@ public class ConsumerDelegateCreator {
             GroupProtocol groupProtocol = 
GroupProtocol.valueOf(config.getString(ConsumerConfig.GROUP_PROTOCOL_CONFIG).toUpperCase(Locale.ROOT));
 
             if (groupProtocol == GroupProtocol.CONSUMER)
-                return new AsyncKafkaConsumer<>(config, keyDeserializer, 
valueDeserializer);
+                return new AsyncKafkaConsumer<>(config, keyDeserializer, 
valueDeserializer, Optional.empty(), Optional.empty());
             else
                 return new ClassicKafkaConsumer<>(config, keyDeserializer, 
valueDeserializer);
         } catch (KafkaException e) {
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestManagers.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestManagers.java
index 9b9fb48482b..9f406bd2090 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestManagers.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestManagers.java
@@ -54,10 +54,12 @@ public class RequestManagers implements Closeable {
     public final Optional<ShareHeartbeatRequestManager> 
shareHeartbeatRequestManager;
     public final Optional<ConsumerMembershipManager> consumerMembershipManager;
     public final Optional<ShareMembershipManager> shareMembershipManager;
+    public final Optional<StreamsMembershipManager> streamsMembershipManager;
     public final OffsetsRequestManager offsetsRequestManager;
     public final TopicMetadataRequestManager topicMetadataRequestManager;
     public final FetchRequestManager fetchRequestManager;
     public final Optional<ShareConsumeRequestManager> 
shareConsumeRequestManager;
+    public final Optional<StreamsGroupHeartbeatRequestManager> 
streamsGroupHeartbeatRequestManager;
     private final List<Optional<? extends RequestManager>> entries;
     private final IdempotentCloser closer = new IdempotentCloser();
 
@@ -68,7 +70,9 @@ public class RequestManagers implements Closeable {
                            Optional<CoordinatorRequestManager> 
coordinatorRequestManager,
                            Optional<CommitRequestManager> commitRequestManager,
                            Optional<ConsumerHeartbeatRequestManager> 
heartbeatRequestManager,
-                           Optional<ConsumerMembershipManager> 
membershipManager) {
+                           Optional<ConsumerMembershipManager> 
membershipManager,
+                           Optional<StreamsGroupHeartbeatRequestManager> 
streamsGroupHeartbeatRequestManager,
+                           Optional<StreamsMembershipManager> 
streamsMembershipManager) {
         this.log = logContext.logger(RequestManagers.class);
         this.offsetsRequestManager = requireNonNull(offsetsRequestManager, 
"OffsetsRequestManager cannot be null");
         this.coordinatorRequestManager = coordinatorRequestManager;
@@ -78,7 +82,9 @@ public class RequestManagers implements Closeable {
         this.shareConsumeRequestManager = Optional.empty();
         this.consumerHeartbeatRequestManager = heartbeatRequestManager;
         this.shareHeartbeatRequestManager = Optional.empty();
+        this.streamsGroupHeartbeatRequestManager = 
streamsGroupHeartbeatRequestManager;
         this.consumerMembershipManager = membershipManager;
+        this.streamsMembershipManager = streamsMembershipManager;
         this.shareMembershipManager = Optional.empty();
 
         List<Optional<? extends RequestManager>> list = new ArrayList<>();
@@ -86,6 +92,8 @@ public class RequestManagers implements Closeable {
         list.add(commitRequestManager);
         list.add(heartbeatRequestManager);
         list.add(membershipManager);
+        list.add(streamsGroupHeartbeatRequestManager);
+        list.add(streamsMembershipManager);
         list.add(Optional.of(offsetsRequestManager));
         list.add(Optional.of(topicMetadataRequestManager));
         list.add(Optional.of(fetchRequestManager));
@@ -102,8 +110,10 @@ public class RequestManagers implements Closeable {
         this.coordinatorRequestManager = coordinatorRequestManager;
         this.commitRequestManager = Optional.empty();
         this.consumerHeartbeatRequestManager = Optional.empty();
+        this.streamsGroupHeartbeatRequestManager = Optional.empty();
         this.shareHeartbeatRequestManager = shareHeartbeatRequestManager;
         this.consumerMembershipManager = Optional.empty();
+        this.streamsMembershipManager = Optional.empty();
         this.shareMembershipManager = shareMembershipManager;
         this.offsetsRequestManager = null;
         this.topicMetadataRequestManager = null;
@@ -158,8 +168,10 @@ public class RequestManagers implements Closeable {
                                                      final 
Optional<ClientTelemetryReporter> clientTelemetryReporter,
                                                      final Metrics metrics,
                                                      final 
OffsetCommitCallbackInvoker offsetCommitCallbackInvoker,
-                                                     final MemberStateListener 
applicationThreadMemberStateListener
-                                                     ) {
+                                                     final MemberStateListener 
applicationThreadMemberStateListener,
+                                                     final 
Optional<StreamsRebalanceData> streamsRebalanceData,
+                                                     final 
Optional<StreamsRebalanceEventsProcessor> streamsRebalanceEventsProcessor
+    ) {
         return new CachedSupplier<>() {
             @Override
             protected RequestManagers create() {
@@ -187,26 +199,56 @@ public class RequestManagers implements Closeable {
                 ConsumerMembershipManager membershipManager = null;
                 CoordinatorRequestManager coordinator = null;
                 CommitRequestManager commitRequestManager = null;
+                StreamsGroupHeartbeatRequestManager 
streamsGroupHeartbeatRequestManager = null;
+                StreamsMembershipManager streamsMembershipManager = null;
 
                 if (groupRebalanceConfig != null && 
groupRebalanceConfig.groupId != null) {
                     Optional<String> serverAssignor = 
Optional.ofNullable(config.getString(ConsumerConfig.GROUP_REMOTE_ASSIGNOR_CONFIG));
                     coordinator = new CoordinatorRequestManager(
-                            logContext,
-                            retryBackoffMs,
-                            retryBackoffMaxMs,
-                            groupRebalanceConfig.groupId);
+                        logContext,
+                        retryBackoffMs,
+                        retryBackoffMaxMs,
+                        groupRebalanceConfig.groupId);
                     commitRequestManager = new CommitRequestManager(
+                        time,
+                        logContext,
+                        subscriptions,
+                        config,
+                        coordinator,
+                        offsetCommitCallbackInvoker,
+                        groupRebalanceConfig.groupId,
+                        groupRebalanceConfig.groupInstanceId,
+                        metrics,
+                        metadata);
+                    if (streamsRebalanceEventsProcessor.isPresent()) {
+                        streamsMembershipManager = new 
StreamsMembershipManager(
+                            groupRebalanceConfig.groupId,
+                            streamsRebalanceEventsProcessor.get(),
+                            streamsRebalanceData.get(),
+                            subscriptions,
+                            logContext,
                             time,
+                            metrics);
+                        
streamsMembershipManager.registerStateListener(commitRequestManager);
+                        
streamsMembershipManager.registerStateListener(applicationThreadMemberStateListener);
+
+                        if (clientTelemetryReporter.isPresent()) {
+                            clientTelemetryReporter.get()
+                                
.updateMetricsLabels(Map.of(ClientTelemetryProvider.GROUP_MEMBER_ID, 
streamsMembershipManager.memberId()));
+                        }
+
+                        streamsGroupHeartbeatRequestManager = new 
StreamsGroupHeartbeatRequestManager(
                             logContext,
-                            subscriptions,
+                            time,
                             config,
                             coordinator,
-                            offsetCommitCallbackInvoker,
-                            groupRebalanceConfig.groupId,
-                            groupRebalanceConfig.groupInstanceId,
+                            streamsMembershipManager,
+                            backgroundEventHandler,
                             metrics,
-                            metadata);
-                    membershipManager = new ConsumerMembershipManager(
+                            streamsRebalanceData.get()
+                        );
+                    } else {
+                        membershipManager = new ConsumerMembershipManager(
                             groupRebalanceConfig.groupId,
                             groupRebalanceConfig.groupInstanceId,
                             groupRebalanceConfig.rebalanceTimeoutMs,
@@ -220,17 +262,17 @@ public class RequestManagers implements Closeable {
                             metrics,
                             
config.getBoolean(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG));
 
-                    // Update the group member ID label in the client 
telemetry reporter.
-                    // According to KIP-1082, the consumer will generate the 
member ID as the incarnation ID of the process.
-                    // Therefore, we can update the group member ID during 
initialization.
-                    if (clientTelemetryReporter.isPresent()) {
-                        clientTelemetryReporter.get()
-                            
.updateMetricsLabels(Map.of(ClientTelemetryProvider.GROUP_MEMBER_ID, 
membershipManager.memberId()));
-                    }
+                        // Update the group member ID label in the client 
telemetry reporter.
+                        // According to KIP-1082, the consumer will generate 
the member ID as the incarnation ID of the process.
+                        // Therefore, we can update the group member ID during 
initialization.
+                        if (clientTelemetryReporter.isPresent()) {
+                            clientTelemetryReporter.get()
+                                
.updateMetricsLabels(Map.of(ClientTelemetryProvider.GROUP_MEMBER_ID, 
membershipManager.memberId()));
+                        }
 
-                    
membershipManager.registerStateListener(commitRequestManager);
-                    
membershipManager.registerStateListener(applicationThreadMemberStateListener);
-                    heartbeatRequestManager = new 
ConsumerHeartbeatRequestManager(
+                        
membershipManager.registerStateListener(commitRequestManager);
+                        
membershipManager.registerStateListener(applicationThreadMemberStateListener);
+                        heartbeatRequestManager = new 
ConsumerHeartbeatRequestManager(
                             logContext,
                             time,
                             config,
@@ -239,6 +281,7 @@ public class RequestManagers implements Closeable {
                             membershipManager,
                             backgroundEventHandler,
                             metrics);
+                    }
                 }
 
                 final OffsetsRequestManager listOffsets = new 
OffsetsRequestManager(subscriptions,
@@ -261,7 +304,9 @@ public class RequestManagers implements Closeable {
                         Optional.ofNullable(coordinator),
                         Optional.ofNullable(commitRequestManager),
                         Optional.ofNullable(heartbeatRequestManager),
-                        Optional.ofNullable(membershipManager)
+                        Optional.ofNullable(membershipManager),
+                        
Optional.ofNullable(streamsGroupHeartbeatRequestManager),
+                        Optional.ofNullable(streamsMembershipManager)
                 );
             }
         };
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManager.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManager.java
index 319a708e216..55741114b34 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManager.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsGroupHeartbeatRequestManager.java
@@ -74,13 +74,13 @@ public class StreamsGroupHeartbeatRequestManager implements 
RequestManager {
         // Fields of StreamsGroupHeartbeatRequest sent in the most recent 
request
         static class LastSentFields {
 
-            private StreamsRebalanceData.Assignment assignment = 
StreamsRebalanceData.Assignment.EMPTY;
+            private StreamsRebalanceData.Assignment assignment = null;
 
             LastSentFields() {
             }
 
             void reset() {
-                assignment = StreamsRebalanceData.Assignment.EMPTY;
+                assignment = null;
             }
         }
 
@@ -402,6 +402,10 @@ public class StreamsGroupHeartbeatRequestManager 
implements RequestManager {
         return EMPTY;
     }
 
+    public StreamsMembershipManager membershipManager() {
+        return membershipManager;
+    }
+
     /**
      * Returns the delay for which the application thread can safely wait 
before it should be responsive
      * to results from the request managers. For example, the subscription 
state can change when heartbeats
@@ -425,6 +429,17 @@ public class StreamsGroupHeartbeatRequestManager 
implements RequestManager {
         return Math.min(pollTimer.remainingMs() / 2, 
heartbeatRequestState.timeToNextHeartbeatMs(currentTimeMs));
     }
 
+    public void resetPollTimer(final long pollMs) {
+        pollTimer.update(pollMs);
+        if (pollTimer.isExpired()) {
+            logger.warn("Time between subsequent calls to poll() was longer 
than the configured " +
+                    "max.poll.interval.ms, exceeded approximately by {} ms. 
Member {} will rejoin the group now.",
+                pollTimer.isExpiredBy(), membershipManager.memberId());
+            membershipManager.maybeRejoinStaleMember();
+        }
+        pollTimer.reset(maxPollIntervalMs);
+    }
+
     /**
      * A heartbeat should be sent without waiting for the heartbeat interval 
to expire if:
      * - the member is leaving the group
@@ -508,7 +523,7 @@ public class StreamsGroupHeartbeatRequestManager implements 
RequestManager {
             String statusDetails = statuses.stream()
                 .map(status -> "(" + status.statusCode() + ") " + 
status.statusDetail())
                 .collect(Collectors.joining(", "));
-            logger.warn("Membership is in the following statuses: {}.", 
statusDetails);
+            logger.warn("Membership is in the following statuses: {}", 
statusDetails);
         }
 
         membershipManager.onHeartbeatSuccess(response);
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
index 43e804d84bd..6157b66cf16 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceData.java
@@ -75,10 +75,7 @@ public class StreamsRebalanceData {
 
         @Override
         public String toString() {
-            return "TaskId{" +
-                "subtopologyId=" + subtopologyId +
-                ", partitionId=" + partitionId +
-                '}';
+            return subtopologyId + "_" + partitionId;
         }
     }
 
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
index db91e8fcece..30bc7dcfb07 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessor.java
@@ -45,20 +45,17 @@ import java.util.concurrent.LinkedBlockingQueue;
 public class StreamsRebalanceEventsProcessor {
 
     private final BlockingQueue<BackgroundEvent> onCallbackRequests = new 
LinkedBlockingQueue<>();
-    private ApplicationEventHandler applicationEventHandler = null;
-    private final StreamsGroupRebalanceCallbacks rebalanceCallbacks;
     private final StreamsRebalanceData streamsRebalanceData;
+    private ApplicationEventHandler applicationEventHandler;
+    private StreamsGroupRebalanceCallbacks rebalanceCallbacks;
 
     /**
      * Constructs the Streams rebalance processor.
      *
      * @param streamsRebalanceData
-     * @param rebalanceCallbacks
      */
-    public StreamsRebalanceEventsProcessor(StreamsRebalanceData 
streamsRebalanceData,
-                                           StreamsGroupRebalanceCallbacks 
rebalanceCallbacks) {
+    public StreamsRebalanceEventsProcessor(StreamsRebalanceData 
streamsRebalanceData) {
         this.streamsRebalanceData = streamsRebalanceData;
-        this.rebalanceCallbacks = rebalanceCallbacks;
     }
 
     /**
@@ -107,6 +104,10 @@ public class StreamsRebalanceEventsProcessor {
         this.applicationEventHandler = applicationEventHandler;
     }
 
+    public void setRebalanceCallbacks(final StreamsGroupRebalanceCallbacks 
rebalanceCallbacks) {
+        this.rebalanceCallbacks = rebalanceCallbacks;
+    }
+
     private void process(final BackgroundEvent event) {
         switch (event.type()) {
             case ERROR:
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessor.java
 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessor.java
index 42dd1711b96..1ca51cca62e 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessor.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessor.java
@@ -72,7 +72,7 @@ public class ApplicationEventProcessor implements 
EventProcessor<ApplicationEven
         this.metadataVersionSnapshot = metadata.updateVersion();
     }
 
-    @SuppressWarnings({"CyclomaticComplexity"})
+    @SuppressWarnings({"CyclomaticComplexity", "JavaNCSSCheck"})
     @Override
     public void process(ApplicationEvent event) {
         switch (event.type()) {
@@ -200,6 +200,18 @@ public class ApplicationEventProcessor implements 
EventProcessor<ApplicationEven
                 process((CurrentLagEvent) event);
                 return;
 
+            case STREAMS_ON_TASKS_REVOKED_CALLBACK_COMPLETED:
+                process((StreamsOnTasksRevokedCallbackCompletedEvent) event);
+                return;
+
+            case STREAMS_ON_TASKS_ASSIGNED_CALLBACK_COMPLETED:
+                process((StreamsOnTasksAssignedCallbackCompletedEvent) event);
+                return;
+
+            case STREAMS_ON_ALL_TASKS_LOST_CALLBACK_COMPLETED:
+                process((StreamsOnAllTasksLostCallbackCompletedEvent) event);
+                return;
+
             default:
                 log.warn("Application event type {} was not expected", 
event.type());
         }
@@ -220,6 +232,10 @@ public class ApplicationEventProcessor implements 
EventProcessor<ApplicationEven
                 hrm.membershipManager().onConsumerPoll();
                 hrm.resetPollTimer(event.pollTimeMs());
             });
+            requestManagers.streamsGroupHeartbeatRequestManager.ifPresent(hrm 
-> {
+                hrm.membershipManager().onConsumerPoll();
+                hrm.resetPollTimer(event.pollTimeMs());
+            });
         } else {
             // safe to unblock - no auto-commit risk here:
             // 1. commitRequestManager is not present
@@ -320,22 +336,32 @@ public class ApplicationEventProcessor implements 
EventProcessor<ApplicationEven
      * it is already a member on the next poll.
      */
     private void process(final TopicSubscriptionChangeEvent event) {
-        if (requestManagers.consumerHeartbeatRequestManager.isEmpty()) {
+        if (requestManagers.consumerHeartbeatRequestManager.isPresent()) {
+            try {
+                if (subscriptions.subscribe(event.topics(), event.listener())) 
{
+                    this.metadataVersionSnapshot = 
metadata.requestUpdateForNewTopics();
+                }
+                // Join the group if not already part of it, or just send the 
new subscription to the broker on the next poll.
+                
requestManagers.consumerHeartbeatRequestManager.get().membershipManager().onSubscriptionUpdated();
+                event.future().complete(null);
+            } catch (Exception e) {
+                event.future().completeExceptionally(e);
+            }
+        } else if 
(requestManagers.streamsGroupHeartbeatRequestManager.isPresent()) {
+            try {
+                if (subscriptions.subscribe(event.topics(), event.listener())) 
{
+                    this.metadataVersionSnapshot = 
metadata.requestUpdateForNewTopics();
+                }
+                
requestManagers.streamsMembershipManager.get().onSubscriptionUpdated();
+                event.future().complete(null);
+            } catch (Exception e) {
+                event.future().completeExceptionally(e);
+            }
+        } else {
             log.warn("Group membership manager not present when processing a 
subscribe event");
             event.future().complete(null);
-            return;
         }
 
-        try {
-            if (subscriptions.subscribe(event.topics(), event.listener()))
-                this.metadataVersionSnapshot = 
metadata.requestUpdateForNewTopics();
-
-            // Join the group if not already part of it, or just send the new 
subscription to the broker on the next poll.
-            
requestManagers.consumerHeartbeatRequestManager.get().membershipManager().onSubscriptionUpdated();
-            event.future().complete(null);
-        } catch (Exception e) {
-            event.future().completeExceptionally(e);
-        }
     }
 
     /**
@@ -405,6 +431,9 @@ public class ApplicationEventProcessor implements 
EventProcessor<ApplicationEven
         if (requestManagers.consumerHeartbeatRequestManager.isPresent()) {
             CompletableFuture<Void> future = 
requestManagers.consumerHeartbeatRequestManager.get().membershipManager().leaveGroup();
             future.whenComplete(complete(event.future()));
+        } else if 
(requestManagers.streamsGroupHeartbeatRequestManager.isPresent()) {
+            CompletableFuture<Void> future = 
requestManagers.streamsGroupHeartbeatRequestManager.get().membershipManager().leaveGroup();
+            future.whenComplete(complete(event.future()));
         } else {
             // If the consumer is not using the group management capabilities, 
we still need to clear all assignments it may have.
             subscriptions.unsubscribe();
@@ -463,12 +492,15 @@ public class ApplicationEventProcessor implements 
EventProcessor<ApplicationEven
     }
 
     private void process(final LeaveGroupOnCloseEvent event) {
-        if (requestManagers.consumerMembershipManager.isEmpty())
-            return;
-
-        log.debug("Signal the ConsumerMembershipManager to leave the consumer 
group since the consumer is closing");
-        CompletableFuture<Void> future = 
requestManagers.consumerMembershipManager.get().leaveGroupOnClose(event.membershipOperation());
-        future.whenComplete(complete(event.future()));
+        if (requestManagers.consumerMembershipManager.isPresent()) {
+            log.debug("Signal the ConsumerMembershipManager to leave the 
consumer group since the consumer is closing");
+            CompletableFuture<Void> future = 
requestManagers.consumerMembershipManager.get().leaveGroupOnClose(event.membershipOperation());
+            future.whenComplete(complete(event.future()));
+        } else if (requestManagers.streamsMembershipManager.isPresent()) {
+            log.debug("Signal the StreamsMembershipManager to leave the 
Streams group since the member is closing");
+            CompletableFuture<Void> future = 
requestManagers.streamsMembershipManager.get().leaveGroupOnClose();
+            future.whenComplete(complete(event.future()));
+        }
     }
 
     private void process(@SuppressWarnings("unused") final 
StopFindCoordinatorOnCloseEvent event) {
@@ -667,6 +699,33 @@ public class ApplicationEventProcessor implements 
EventProcessor<ApplicationEven
         }
     }
 
+    private void process(final StreamsOnTasksRevokedCallbackCompletedEvent 
event) {
+        if (requestManagers.streamsMembershipManager.isEmpty()) {
+            log.warn("An internal error occurred; the Streams membership 
manager was not present, so the notification " +
+                "of the onTasksRevoked callback execution could not be sent");
+            return;
+        }
+        
requestManagers.streamsMembershipManager.get().onTasksRevokedCallbackCompleted(event);
+    }
+
+    private void process(final StreamsOnTasksAssignedCallbackCompletedEvent 
event) {
+        if (requestManagers.streamsMembershipManager.isEmpty()) {
+            log.warn("An internal error occurred; the Streams membership 
manager was not present, so the notification " +
+                "of the onTasksAssigned callback execution could not be sent");
+            return;
+        }
+        
requestManagers.streamsMembershipManager.get().onTasksAssignedCallbackCompleted(event);
+    }
+
+    private void process(final StreamsOnAllTasksLostCallbackCompletedEvent 
event) {
+        if (requestManagers.streamsMembershipManager.isEmpty()) {
+            log.warn("An internal error occurred; the Streams membership 
manager was not present, so the notification " +
+                "of the onAllTasksLost callback execution could not be sent");
+            return;
+        }
+        
requestManagers.streamsMembershipManager.get().onAllTasksLostCallbackCompleted(event);
+    }
+
     private <T> BiConsumer<? super T, ? super Throwable> complete(final 
CompletableFuture<T> b) {
         return (value, exception) -> {
             if (exception != null)
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java
index e501f3aeea3..df091bfccb6 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java
@@ -219,7 +219,9 @@ public class AsyncKafkaConsumerTest {
             a -> backgroundEventReaper,
             (a, b, c, d, e, f, g) -> fetchCollector,
             (a, b, c, d) -> metadata,
-            backgroundEventQueue
+            backgroundEventQueue,
+            Optional.empty(),
+            Optional.empty()
         );
     }
 
@@ -233,7 +235,9 @@ public class AsyncKafkaConsumerTest {
             a -> backgroundEventReaper,
             (a, b, c, d, e, f, g) -> fetchCollector,
             (a, b, c, d) -> metadata,
-            backgroundEventQueue
+            backgroundEventQueue,
+            Optional.empty(),
+            Optional.empty()
         );
     }
 
@@ -1315,7 +1319,9 @@ public class AsyncKafkaConsumerTest {
             any(),
             any(),
             any(),
-            applicationThreadMemberStateListener.capture()
+            applicationThreadMemberStateListener.capture(),
+            any(),
+            any()
         ));
         return applicationThreadMemberStateListener.getValue();
     }
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/RequestManagersTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/RequestManagersTest.java
index 405ecabcf16..ef3ba79dbf5 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/RequestManagersTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/RequestManagersTest.java
@@ -62,7 +62,9 @@ public class RequestManagersTest {
             Optional.empty(),
             new Metrics(),
             mock(OffsetCommitCallbackInvoker.class),
-            listener
+            listener,
+            Optional.empty(),
+            Optional.empty()
         ).get();
         requestManagers.consumerMembershipManager.ifPresent(
             membershipManager -> 
assertTrue(membershipManager.stateListeners().contains(listener))
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java
index 6c7b164826b..5bc988cf9c7 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceEventsProcessorTest.java
@@ -74,7 +74,8 @@ public class StreamsRebalanceEventsProcessorTest {
     @Test
     public void shouldInvokeOnTasksAssignedCallback() {
         final StreamsRebalanceEventsProcessor rebalanceEventsProcessor =
-            new StreamsRebalanceEventsProcessor(rebalanceData, 
rebalanceCallbacks);
+            new StreamsRebalanceEventsProcessor(rebalanceData);
+        rebalanceEventsProcessor.setRebalanceCallbacks(rebalanceCallbacks);
         
rebalanceEventsProcessor.setApplicationEventHandler(applicationEventHandler);
         final Set<StreamsRebalanceData.TaskId> activeTasks = Set.of(
             new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0),
@@ -109,7 +110,8 @@ public class StreamsRebalanceEventsProcessorTest {
     @Test
     public void 
shouldReThrowErrorFromOnTasksAssignedCallbackAndPassErrorToBackground() {
         final StreamsRebalanceEventsProcessor rebalanceEventsProcessor =
-            new StreamsRebalanceEventsProcessor(rebalanceData, 
rebalanceCallbacks);
+            new StreamsRebalanceEventsProcessor(rebalanceData);
+        rebalanceEventsProcessor.setRebalanceCallbacks(rebalanceCallbacks);
         
rebalanceEventsProcessor.setApplicationEventHandler(applicationEventHandler);
         final Set<StreamsRebalanceData.TaskId> activeTasks = Set.of(
             new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0),
@@ -148,7 +150,8 @@ public class StreamsRebalanceEventsProcessorTest {
     @Test
     public void shouldInvokeOnTasksRevokedCallback() {
         final StreamsRebalanceEventsProcessor rebalanceEventsProcessor =
-            new StreamsRebalanceEventsProcessor(rebalanceData, 
rebalanceCallbacks);
+            new StreamsRebalanceEventsProcessor(rebalanceData);
+        rebalanceEventsProcessor.setRebalanceCallbacks(rebalanceCallbacks);
         
rebalanceEventsProcessor.setApplicationEventHandler(applicationEventHandler);
         final Set<StreamsRebalanceData.TaskId> activeTasks = Set.of(
             new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0),
@@ -172,7 +175,8 @@ public class StreamsRebalanceEventsProcessorTest {
     @Test
     public void 
shouldReThrowErrorFromOnTasksRevokedCallbackAndPassErrorToBackground() {
         final StreamsRebalanceEventsProcessor rebalanceEventsProcessor =
-            new StreamsRebalanceEventsProcessor(rebalanceData, 
rebalanceCallbacks);
+            new StreamsRebalanceEventsProcessor(rebalanceData);
+        rebalanceEventsProcessor.setRebalanceCallbacks(rebalanceCallbacks);
         
rebalanceEventsProcessor.setApplicationEventHandler(applicationEventHandler);
         final Set<StreamsRebalanceData.TaskId> activeTasks = Set.of(
             new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0),
@@ -199,7 +203,8 @@ public class StreamsRebalanceEventsProcessorTest {
     @Test
     public void shouldInvokeOnAllTasksLostCallback() {
         final StreamsRebalanceEventsProcessor rebalanceEventsProcessor =
-            new StreamsRebalanceEventsProcessor(rebalanceData, 
rebalanceCallbacks);
+            new StreamsRebalanceEventsProcessor(rebalanceData);
+        rebalanceEventsProcessor.setRebalanceCallbacks(rebalanceCallbacks);
         
rebalanceEventsProcessor.setApplicationEventHandler(applicationEventHandler);
         final Set<StreamsRebalanceData.TaskId> activeTasks = Set.of(
             new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0),
@@ -238,7 +243,8 @@ public class StreamsRebalanceEventsProcessorTest {
     @Test
     public void 
shouldReThrowErrorFromOnAllTasksLostCallbackAndPassErrorToBackground() {
         final StreamsRebalanceEventsProcessor rebalanceEventsProcessor =
-            new StreamsRebalanceEventsProcessor(rebalanceData, 
rebalanceCallbacks);
+            new StreamsRebalanceEventsProcessor(rebalanceData);
+        rebalanceEventsProcessor.setRebalanceCallbacks(rebalanceCallbacks);
         
rebalanceEventsProcessor.setApplicationEventHandler(applicationEventHandler);
         final Set<StreamsRebalanceData.TaskId> activeTasks = Set.of(
             new StreamsRebalanceData.TaskId(SUBTOPOLOGY_0, 0),
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessorTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessorTest.java
index 75d696806a4..9cd306a9be1 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessorTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessorTest.java
@@ -97,7 +97,10 @@ public class ApplicationEventProcessorTest {
                 withGroupId ? 
Optional.of(mock(CoordinatorRequestManager.class)) : Optional.empty(),
                 withGroupId ? Optional.of(commitRequestManager) : 
Optional.empty(),
                 withGroupId ? Optional.of(heartbeatRequestManager) : 
Optional.empty(),
-                withGroupId ? Optional.of(membershipManager) : 
Optional.empty());
+                withGroupId ? Optional.of(membershipManager) : 
Optional.empty(),
+                Optional.empty(),
+                Optional.empty()
+        );
         processor = new ApplicationEventProcessor(
                 new LogContext(),
                 requestManagers,
diff --git 
a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java
 
b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java
index aaf2af7b1cc..01e899ecc21 100644
--- 
a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java
+++ 
b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java
@@ -18,6 +18,9 @@ package org.apache.kafka.streams.integration;
 
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.common.utils.Exit;
+import org.apache.kafka.coordinator.group.GroupCoordinatorConfig;
+import org.apache.kafka.server.config.ServerConfigs;
+import org.apache.kafka.streams.GroupProtocol;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsConfig.InternalConfig;
 import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
@@ -27,7 +30,9 @@ import org.apache.kafka.streams.tests.SmokeTestDriver;
 
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.TestInfo;
 import org.junit.jupiter.api.Timeout;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.CsvSource;
@@ -35,29 +40,40 @@ import org.junit.jupiter.params.provider.CsvSource;
 import java.io.IOException;
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Properties;
 import java.util.Set;
 
 import static org.apache.kafka.streams.tests.SmokeTestDriver.generate;
 import static org.apache.kafka.streams.tests.SmokeTestDriver.verify;
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 @Timeout(600)
 @Tag("integration")
 public class SmokeTestDriverIntegrationTest {
-    public static final EmbeddedKafkaCluster CLUSTER = new 
EmbeddedKafkaCluster(3);
+    public static EmbeddedKafkaCluster cluster;
+    public TestInfo testInfo;
 
     @BeforeAll
     public static void startCluster() throws IOException {
-        CLUSTER.start();
+        final Properties props = new Properties();
+        
props.setProperty(GroupCoordinatorConfig.GROUP_COORDINATOR_REBALANCE_PROTOCOLS_CONFIG,
 "classic,consumer,streams");
+        props.setProperty(ServerConfigs.UNSTABLE_API_VERSIONS_ENABLE_CONFIG, 
"true");
+        cluster = new EmbeddedKafkaCluster(3, props);
+        cluster.start();
     }
 
     @AfterAll
     public static void closeCluster() {
-        CLUSTER.stop();
+        cluster.stop();
     }
 
+    @BeforeEach
+    public void setUp(final TestInfo testInfo) {
+        this.testInfo = testInfo;
+    }
 
     private static class Driver extends Thread {
         private final String bootstrapServers;
@@ -99,8 +115,17 @@ public class SmokeTestDriverIntegrationTest {
     // We set 2 timeout condition to fail the test before passing the 
verification:
     // (1) 10 min timeout, (2) 30 tries of polling without getting any data
     @ParameterizedTest
-    @CsvSource({"false, false", "true, false"})
-    public void shouldWorkWithRebalance(final boolean stateUpdaterEnabled, 
final boolean processingThreadsEnabled) throws InterruptedException {
+    @CsvSource({
+        "false, false, true",
+        "true, false, true",
+        "false, false, false",
+        "true, false, false",
+    })
+    public void shouldWorkWithRebalance(
+        final boolean stateUpdaterEnabled,
+        final boolean processingThreadsEnabled,
+        final boolean streamsProtocolEnabled
+    ) throws InterruptedException {
         Exit.setExitProcedure((statusCode, message) -> {
             throw new AssertionError("Test called exit(). code:" + statusCode 
+ " message:" + message);
         });
@@ -110,9 +135,9 @@ public class SmokeTestDriverIntegrationTest {
         int numClientsCreated = 0;
         final ArrayList<SmokeTestClient> clients = new ArrayList<>();
 
-        IntegrationTestUtils.cleanStateBeforeTest(CLUSTER, 
SmokeTestDriver.topics());
+        IntegrationTestUtils.cleanStateBeforeTest(cluster, 
SmokeTestDriver.topics());
 
-        final String bootstrapServers = CLUSTER.bootstrapServers();
+        final String bootstrapServers = cluster.bootstrapServers();
         final Driver driver = new Driver(bootstrapServers, 10, 1000);
         driver.start();
         System.out.println("started driver");
@@ -120,10 +145,15 @@ public class SmokeTestDriverIntegrationTest {
 
         final Properties props = new Properties();
         props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers);
+        props.put(StreamsConfig.APPLICATION_ID_CONFIG, 
safeUniqueTestName(testInfo));
         props.put(InternalConfig.STATE_UPDATER_ENABLED, stateUpdaterEnabled);
         props.put(InternalConfig.PROCESSING_THREADS_ENABLED, 
processingThreadsEnabled);
         // decrease the session timeout so that we can trigger the rebalance 
soon after old client left closed
         props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000);
+        props.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 500);
+        if (streamsProtocolEnabled) {
+            props.put(StreamsConfig.GROUP_PROTOCOL_CONFIG, 
GroupProtocol.STREAMS.name().toLowerCase(Locale.getDefault()));
+        }
 
         // cycle out Streams instances as long as the test is running.
         while (driver.isAlive()) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/GroupProtocol.java 
b/streams/src/main/java/org/apache/kafka/streams/GroupProtocol.java
new file mode 100644
index 00000000000..146a5e6e9de
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/GroupProtocol.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams;
+
+import java.util.Locale;
+
+public enum GroupProtocol {
+    /** Classic group protocol. */
+    CLASSIC("CLASSIC"),
+
+    /** Streams group protocol. */
+    STREAMS("STREAMS");
+
+    /**
+     * String representation of the group protocol.
+     */
+    public final String name;
+
+    GroupProtocol(final String name) {
+        this.name = name;
+    }
+
+    /**
+     * Case-insensitive group protocol lookup by string name.
+     */
+    public static GroupProtocol of(final String name) {
+        return GroupProtocol.valueOf(name.toUpperCase(Locale.ROOT));
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java 
b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
index b98806590a1..e19cb03b207 100644
--- a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
+++ b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
@@ -68,6 +68,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Properties;
@@ -608,6 +609,16 @@ public class StreamsConfig extends AbstractConfig {
         " topics (e.g., changelog and repartition topics) and their associated 
state stores." +
         " When enabled, the application will refuse to start if any internal 
resource has an auto-generated name.";
 
+    /**
+     * <code>group.protocol</code>
+     */
+    public static final String GROUP_PROTOCOL_CONFIG = "group.protocol";
+    public static final String DEFAULT_GROUP_PROTOCOL = 
GroupProtocol.CLASSIC.name().toLowerCase(
+        Locale.ROOT);
+    private static final String GROUP_PROTOCOL_DOC = "The group protocol 
streams should use. We currently " +
+        "support \"classic\" or \"streams\". If \"streams\" is specified, then 
the streams rebalance protocol will be " +
+        "used. Otherwise, the classic group protocol will be used.";
+
     /** {@code log.summary.interval.ms} */
     public static final String LOG_SUMMARY_INTERVAL_MS_CONFIG = 
"log.summary.interval.ms";
     private static final String LOG_SUMMARY_INTERVAL_MS_DOC = "The output 
interval in milliseconds for logging summary information.\n" +
@@ -1042,6 +1053,12 @@ public class StreamsConfig extends AbstractConfig {
                         TOPOLOGY_OPTIMIZATION_CONFIGS::toString),
                     Importance.MEDIUM,
                     TOPOLOGY_OPTIMIZATION_DOC)
+            .define(GROUP_PROTOCOL_CONFIG,
+                    Type.STRING,
+                    DEFAULT_GROUP_PROTOCOL,
+                    
ConfigDef.CaseInsensitiveValidString.in(Utils.enumOptions(GroupProtocol.class)),
+                    Importance.MEDIUM,
+                    GROUP_PROTOCOL_DOC)
 
             // LOW
 
@@ -1505,6 +1522,11 @@ public class StreamsConfig extends AbstractConfig {
         }
         
verifyTopologyOptimizationConfigs(getString(TOPOLOGY_OPTIMIZATION_CONFIG));
         verifyClientTelemetryConfigs();
+
+        if (doLog && 
getString(GROUP_PROTOCOL_CONFIG).equals(GroupProtocol.STREAMS.name().toLowerCase(Locale.ROOT)))
 {
+            log.warn("The streams rebalance protocol is still in development 
and should not be used in production. "
+                + "Please set group.protocol=classic (default) in all 
production use cases.");
+        }
     }
 
     private void verifyEOSTransactionTimeoutCompatibility() {
@@ -1627,6 +1649,8 @@ public class StreamsConfig extends AbstractConfig {
     private Map<String, Object> getCommonConsumerConfigs() {
         final Map<String, Object> clientProvidedProps = 
getClientPropsWithPrefix(CONSUMER_PREFIX, ConsumerConfig.configNames());
 
+        clientProvidedProps.remove(GROUP_PROTOCOL_CONFIG);
+
         checkIfUnexpectedUserSpecifiedConsumerConfig(clientProvidedProps, 
NON_CONFIGURABLE_CONSUMER_DEFAULT_CONFIGS);
         checkIfUnexpectedUserSpecifiedConsumerConfig(clientProvidedProps, 
NON_CONFIGURABLE_CONSUMER_EOS_CONFIGS);
 
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsGroupRebalanceCallbacks.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsGroupRebalanceCallbacks.java
new file mode 100644
index 00000000000..37e61427ccd
--- /dev/null
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsGroupRebalanceCallbacks.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import 
org.apache.kafka.clients.consumer.internals.StreamsGroupRebalanceCallbacks;
+import org.apache.kafka.clients.consumer.internals.StreamsRebalanceData;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.processor.TaskId;
+
+import org.slf4j.Logger;
+
+import java.util.Collection;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class DefaultStreamsGroupRebalanceCallbacks implements 
StreamsGroupRebalanceCallbacks {
+
+    private final Logger log;
+    private final Time time;
+    private final StreamsRebalanceData streamsRebalanceData;
+    private final TaskManager taskManager;
+    private final StreamThread streamThread;
+
+    public DefaultStreamsGroupRebalanceCallbacks(final Logger log,
+                                                 final Time time,
+                                                 final StreamsRebalanceData 
streamsRebalanceData,
+                                                 final StreamThread 
streamThread,
+                                                 final TaskManager 
taskManager) {
+        this.log = log;
+        this.time = time;
+        this.streamsRebalanceData = streamsRebalanceData;
+        this.streamThread = streamThread;
+        this.taskManager = taskManager;
+    }
+
+    @Override
+    public Optional<Exception> onTasksRevoked(final 
Set<StreamsRebalanceData.TaskId> tasks) {
+        try {
+            final Map<TaskId, Set<TopicPartition>> 
activeTasksToRevokeWithPartitions =
+                pairWithTopicPartitions(tasks.stream());
+            final Set<TopicPartition> partitionsToRevoke = 
activeTasksToRevokeWithPartitions.values().stream()
+                .flatMap(Collection::stream)
+                .collect(Collectors.toSet());
+
+            final long start = time.milliseconds();
+            try {
+                log.info("Revoking active tasks {}.", tasks);
+                taskManager.handleRevocation(partitionsToRevoke);
+            } finally {
+                log.info("partition revocation took {} ms.", 
time.milliseconds() - start);
+            }
+            if (streamThread.state() != StreamThread.State.PENDING_SHUTDOWN) {
+                streamThread.setState(StreamThread.State.PARTITIONS_REVOKED);
+            }
+        } catch (final Exception exception) {
+            return Optional.of(exception);
+        }
+        return Optional.empty();
+    }
+
+    @Override
+    public Optional<Exception> onTasksAssigned(final 
StreamsRebalanceData.Assignment assignment) {
+        try {
+            final Map<TaskId, Set<TopicPartition>> activeTasksWithPartitions =
+                pairWithTopicPartitions(assignment.activeTasks().stream());
+            final Map<TaskId, Set<TopicPartition>> standbyTasksWithPartitions =
+                
pairWithTopicPartitions(Stream.concat(assignment.standbyTasks().stream(), 
assignment.warmupTasks().stream()));
+
+            log.info("Processing new assignment {} from Streams Rebalance 
Protocol", assignment);
+
+            taskManager.handleAssignment(activeTasksWithPartitions, 
standbyTasksWithPartitions);
+            streamThread.setState(StreamThread.State.PARTITIONS_ASSIGNED);
+            taskManager.handleRebalanceComplete();
+        } catch (final Exception exception) {
+            return Optional.of(exception);
+        }
+        return Optional.empty();
+    }
+
+    @Override
+    public Optional<Exception> onAllTasksLost() {
+        try {
+            taskManager.handleLostAll();
+        } catch (final Exception exception) {
+            return Optional.of(exception);
+        }
+        return Optional.empty();
+    }
+
+    private Map<TaskId, Set<TopicPartition>> pairWithTopicPartitions(final 
Stream<StreamsRebalanceData.TaskId> taskIdStream) {
+        return taskIdStream
+            .collect(Collectors.toMap(
+                this::toTaskId,
+                task -> toTopicPartitions(task, 
streamsRebalanceData.subtopologies().get(task.subtopologyId()))
+            ));
+    }
+
+    private TaskId toTaskId(final StreamsRebalanceData.TaskId task) {
+        return new TaskId(Integer.parseInt(task.subtopologyId()), 
task.partitionId());
+    }
+
+    private Set<TopicPartition> toTopicPartitions(final 
StreamsRebalanceData.TaskId task,
+                                                  final 
StreamsRebalanceData.Subtopology subTopology) {
+        return
+            Stream.concat(
+                    subTopology.sourceTopics().stream(),
+                    subTopology.repartitionSourceTopics().keySet().stream()
+                )
+                .map(t -> new TopicPartition(t, task.partitionId()))
+                .collect(Collectors.toSet());
+    }
+}
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 74e7a27ceb9..a96d968ca27 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -25,11 +25,16 @@ import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.InvalidOffsetException;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
+import org.apache.kafka.clients.consumer.internals.AsyncKafkaConsumer;
 import org.apache.kafka.clients.consumer.internals.AutoOffsetResetStrategy;
+import org.apache.kafka.clients.consumer.internals.StreamsRebalanceData;
+import 
org.apache.kafka.clients.consumer.internals.StreamsRebalanceEventsProcessor;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.KafkaFuture;
 import org.apache.kafka.common.Metric;
 import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.errors.TimeoutException;
@@ -40,6 +45,7 @@ import 
org.apache.kafka.common.serialization.ByteArrayDeserializer;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.streams.GroupProtocol;
 import org.apache.kafka.streams.KafkaClientSupplier;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsConfig.InternalConfig;
@@ -60,12 +66,14 @@ import 
org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics;
 import org.apache.kafka.streams.processor.internals.tasks.DefaultTaskManager;
 import 
org.apache.kafka.streams.processor.internals.tasks.DefaultTaskManager.DefaultTaskExecutorCreator;
+import org.apache.kafka.streams.state.HostInfo;
 import org.apache.kafka.streams.state.internals.ThreadCache;
 
 import org.slf4j.Logger;
 
 import java.time.Duration;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -86,6 +94,7 @@ import static 
org.apache.kafka.streams.processor.internals.ClientUtils.adminClie
 import static 
org.apache.kafka.streams.processor.internals.ClientUtils.consumerClientId;
 import static 
org.apache.kafka.streams.processor.internals.ClientUtils.restoreConsumerClientId;
 
+@SuppressWarnings("ClassDataAbstractionCoupling")
 public class StreamThread extends Thread implements ProcessingThread {
 
     private static final String THREAD_ID_SUBSTRING = "-StreamThread-";
@@ -343,6 +352,10 @@ public class StreamThread extends Thread implements 
ProcessingThread {
     // handler for, eg MissingSourceTopicException with named topologies
     private final Queue<StreamsException> nonFatalExceptionsToHandle;
 
+    private final Optional<StreamsRebalanceData> streamsRebalanceData;
+    private final Optional<StreamsRebalanceEventsProcessor> 
streamsRebalanceEventsProcessor;
+    private final StreamsMetadataState streamsMetadataState;
+
     // These are used to signal from outside the stream thread, but the 
variables themselves are internal to the thread
     private final AtomicLong cacheResizeSize = new AtomicLong(-1L);
     private final AtomicBoolean leaveGroupRequested = new AtomicBoolean(false);
@@ -478,18 +491,19 @@ public class StreamThread extends Thread implements 
ProcessingThread {
             consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, 
"none");
         }
 
-        final Consumer<byte[], byte[]> mainConsumer = 
clientSupplier.getConsumer(consumerConfigs);
-        taskManager.setMainConsumer(mainConsumer);
-        referenceContainer.mainConsumer = mainConsumer;
+        final MainConsumerSetup mainConsumerSetup = 
setupMainConsumer(topologyMetadata, config, clientSupplier, processId, log, 
consumerConfigs);
+
+        taskManager.setMainConsumer(mainConsumerSetup.mainConsumer);
+        referenceContainer.mainConsumer = mainConsumerSetup.mainConsumer;
 
-        final StreamsThreadMetricsDelegatingReporter reporter = new 
StreamsThreadMetricsDelegatingReporter(mainConsumer, threadId, 
Optional.of(stateUpdaterId));
+        final StreamsThreadMetricsDelegatingReporter reporter = new 
StreamsThreadMetricsDelegatingReporter(mainConsumerSetup.mainConsumer, 
threadId, Optional.of(stateUpdaterId));
         streamsMetrics.metricsRegistry().addReporter(reporter);
 
         final StreamThread streamThread = new StreamThread(
             time,
             config,
             adminClient,
-            mainConsumer,
+            mainConsumerSetup.mainConsumer,
             restoreConsumer,
             changelogReader,
             originalReset,
@@ -505,12 +519,73 @@ public class StreamThread extends Thread implements 
ProcessingThread {
             referenceContainer.nonFatalExceptionsToHandle,
             shutdownErrorHook,
             streamsUncaughtExceptionHandler,
-            cache::resize
+            cache::resize,
+            mainConsumerSetup.streamsRebalanceData,
+            mainConsumerSetup.streamsRebalanceEventsProcessor,
+            streamsMetadataState
         );
 
         return streamThread.updateThreadMetadata(adminClientId(clientId));
     }
 
+    private static MainConsumerSetup setupMainConsumer(final TopologyMetadata 
topologyMetadata,
+                                                       final StreamsConfig 
config,
+                                                       final 
KafkaClientSupplier clientSupplier,
+                                                       final UUID processId,
+                                                       final Logger log,
+                                                       final Map<String, 
Object> consumerConfigs) {
+        if 
(config.getString(StreamsConfig.GROUP_PROTOCOL_CONFIG).equalsIgnoreCase(GroupProtocol.STREAMS.name))
 {
+            if (topologyMetadata.hasNamedTopologies()) {
+                throw new IllegalStateException("Named topologies and the 
STREAMS protocol cannot be used at the same time.");
+            }
+            log.info("Streams rebalance protocol enabled");
+
+            final Optional<StreamsRebalanceData> streamsRebalanceData = 
Optional.of(
+                initStreamsRebalanceData(
+                    processId,
+                    config,
+                    
parseHostInfo(config.getString(StreamsConfig.APPLICATION_SERVER_CONFIG)),
+                    topologyMetadata
+                )
+            );
+            final Optional<StreamsRebalanceEventsProcessor> 
streamsRebalanceEventsProcessor =
+                Optional.of(new 
StreamsRebalanceEventsProcessor(streamsRebalanceData.get()));
+            final ByteArrayDeserializer keyDeserializer = new 
ByteArrayDeserializer();
+            final ByteArrayDeserializer valueDeserializer = new 
ByteArrayDeserializer();
+            return new MainConsumerSetup(
+                new AsyncKafkaConsumer<>(
+                    new 
ConsumerConfig(ConsumerConfig.appendDeserializerToConfig(consumerConfigs, 
keyDeserializer, valueDeserializer)),
+                    keyDeserializer,
+                    valueDeserializer,
+                    streamsRebalanceData,
+                    streamsRebalanceEventsProcessor
+                ),
+                streamsRebalanceData,
+                streamsRebalanceEventsProcessor
+            );
+        } else {
+            return new MainConsumerSetup(
+                clientSupplier.getConsumer(consumerConfigs),
+                Optional.empty(),
+                Optional.empty()
+            );
+        }
+    }
+
+    private static class MainConsumerSetup {
+        public final Consumer<byte[], byte[]> mainConsumer;
+        public final Optional<StreamsRebalanceData> streamsRebalanceData;
+        public final Optional<StreamsRebalanceEventsProcessor> 
streamsRebalanceEventsProcessor;
+
+        public MainConsumerSetup(final Consumer<byte[], byte[]> mainConsumer,
+                                 final Optional<StreamsRebalanceData> 
streamsRebalanceData,
+                                 final 
Optional<StreamsRebalanceEventsProcessor> streamsRebalanceEventsProcessor) {
+            this.mainConsumer = mainConsumer;
+            this.streamsRebalanceData = streamsRebalanceData;
+            this.streamsRebalanceEventsProcessor = 
streamsRebalanceEventsProcessor;
+        }
+    }
+
     private static DefaultTaskManager maybeCreateSchedulingTaskManager(final 
boolean processingThreadsEnabled,
                                                                        final 
boolean stateUpdaterEnabled,
                                                                        final 
TopologyMetadata topologyMetadata,
@@ -563,7 +638,98 @@ public class StreamThread extends Thread implements 
ProcessingThread {
         }
     }
 
-    @SuppressWarnings("this-escape")
+    private static Optional<StreamsRebalanceData.HostInfo> parseHostInfo(final 
String endpoint) {
+        final HostInfo hostInfo = HostInfo.buildFromEndpoint(endpoint);
+        if (hostInfo == null) {
+            return Optional.empty();
+        } else {
+            return Optional.of(new 
StreamsRebalanceData.HostInfo(hostInfo.host(), hostInfo.port()));
+        }
+    }
+
+    private static StreamsRebalanceData initStreamsRebalanceData(final UUID 
processId,
+                                                                 final 
StreamsConfig config,
+                                                                 final 
Optional<StreamsRebalanceData.HostInfo> endpoint,
+                                                                 final 
TopologyMetadata topologyMetadata) {
+        final InternalTopologyBuilder internalTopologyBuilder = 
topologyMetadata.lookupBuilderForNamedTopology(null);
+
+        final Map<String, StreamsRebalanceData.Subtopology> subtopologies = 
initBrokerTopology(config, internalTopologyBuilder);
+
+        return new StreamsRebalanceData(
+            processId,
+            endpoint,
+            subtopologies,
+            config.getClientTags()
+        );
+    }
+
+    private static Map<String, StreamsRebalanceData.Subtopology> 
initBrokerTopology(final StreamsConfig config,
+                                                                               
     final InternalTopologyBuilder internalTopologyBuilder) {
+        final Map<String, String> defaultTopicConfigs = new HashMap<>();
+        for (final Map.Entry<String, Object> entry : 
config.originalsWithPrefix(StreamsConfig.TOPIC_PREFIX).entrySet()) {
+            if (entry.getValue() != null) {
+                defaultTopicConfigs.put(entry.getKey(), 
entry.getValue().toString());
+            }
+        }
+        final long windowChangeLogAdditionalRetention = 
config.getLong(StreamsConfig.WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG);
+
+        final Map<String, StreamsRebalanceData.Subtopology> subtopologies = 
new HashMap<>();
+        final Collection<Set<String>> copartitionGroups = 
internalTopologyBuilder.copartitionGroups();
+
+        final Set<String> allRepartitionSourceTopics = 
internalTopologyBuilder.subtopologyToTopicsInfo().values().stream()
+            .flatMap(t -> t.repartitionSourceTopics.keySet().stream())
+            .collect(Collectors.toSet());
+
+        for (final Map.Entry<TopologyMetadata.Subtopology, 
InternalTopologyBuilder.TopicsInfo> topicsInfoEntry : 
internalTopologyBuilder.subtopologyToTopicsInfo()
+            .entrySet()) {
+
+            final HashSet<String> allSourceTopics = new 
HashSet<>(topicsInfoEntry.getValue().sourceTopics);
+            topicsInfoEntry.getValue().repartitionSourceTopics.forEach(
+                (repartitionSourceTopic, repartitionTopicInfo) -> {
+                    allSourceTopics.add(repartitionSourceTopic);
+                });
+
+            final Set<String> sourceTopics = 
topicsInfoEntry.getValue().sourceTopics.stream()
+                .filter(topic -> 
!topicsInfoEntry.getValue().repartitionSourceTopics.containsKey(topic))
+                .collect(Collectors.toSet());
+            final Set<String> repartitionSinkTopics = 
topicsInfoEntry.getValue().sinkTopics.stream()
+                .filter(allRepartitionSourceTopics::contains)
+                .collect(Collectors.toSet());
+            final Map<String, StreamsRebalanceData.TopicInfo> 
repartitionSourceTopics = 
topicsInfoEntry.getValue().repartitionSourceTopics.entrySet()
+                .stream()
+                .collect(Collectors.toMap(Map.Entry::getKey, e ->
+                    new 
StreamsRebalanceData.TopicInfo(e.getValue().numberOfPartitions(),
+                        
Optional.of(config.getInt(StreamsConfig.REPLICATION_FACTOR_CONFIG).shortValue()),
+                        e.getValue().properties(defaultTopicConfigs, 
windowChangeLogAdditionalRetention))));
+            final Map<String, StreamsRebalanceData.TopicInfo> 
stateChangelogTopics = 
topicsInfoEntry.getValue().stateChangelogTopics.entrySet()
+                .stream()
+                .collect(Collectors.toMap(Map.Entry::getKey, e ->
+                    new 
StreamsRebalanceData.TopicInfo(e.getValue().numberOfPartitions(),
+                        
Optional.of(config.getInt(StreamsConfig.REPLICATION_FACTOR_CONFIG).shortValue()),
+                        e.getValue().properties(defaultTopicConfigs, 
windowChangeLogAdditionalRetention))));
+
+            subtopologies.put(
+                String.valueOf(topicsInfoEntry.getKey().nodeGroupId),
+                new StreamsRebalanceData.Subtopology(
+                    sourceTopics,
+                    repartitionSinkTopics,
+                    repartitionSourceTopics,
+                    stateChangelogTopics,
+                    
copartitionGroups.stream().filter(allSourceTopics::containsAll).collect(Collectors.toList())
+                )
+            );
+        }
+
+        if (subtopologies.values().stream().mapToInt(x -> 
x.copartitionGroups().size()).sum()
+            != copartitionGroups.size()) {
+            throw new IllegalStateException(
+                "Not all copartition groups were converted to broker 
topology");
+        }
+
+        return subtopologies;
+    }
+
+    @SuppressWarnings({"this-escape"})
     public StreamThread(final Time time,
                         final StreamsConfig config,
                         final Admin adminClient,
@@ -583,7 +749,10 @@ public class StreamThread extends Thread implements 
ProcessingThread {
                         final Queue<StreamsException> 
nonFatalExceptionsToHandle,
                         final Runnable shutdownErrorHook,
                         final BiConsumer<Throwable, Boolean> 
streamsUncaughtExceptionHandler,
-                        final java.util.function.Consumer<Long> cacheResizer
+                        final java.util.function.Consumer<Long> cacheResizer,
+                        final Optional<StreamsRebalanceData> 
streamsRebalanceData,
+                        final Optional<StreamsRebalanceEventsProcessor> 
streamsRebalanceEventsProcessor,
+                        final StreamsMetadataState streamsMetadataState
                         ) {
         super(threadId);
         this.stateLock = new Object();
@@ -666,6 +835,15 @@ public class StreamThread extends Thread implements 
ProcessingThread {
         this.stateUpdaterEnabled = 
InternalConfig.stateUpdaterEnabled(config.originals());
         this.processingThreadsEnabled = 
InternalConfig.processingThreadsEnabled(config.originals());
         this.logSummaryIntervalMs = 
config.getLong(StreamsConfig.LOG_SUMMARY_INTERVAL_MS_CONFIG);
+
+        this.streamsRebalanceData = streamsRebalanceData;
+        this.streamsRebalanceEventsProcessor = streamsRebalanceEventsProcessor;
+        if (streamsRebalanceData.isPresent() && 
streamsRebalanceEventsProcessor.isPresent()) {
+            streamsRebalanceEventsProcessor.get().setRebalanceCallbacks(
+                new DefaultStreamsGroupRebalanceCallbacks(log, time, 
streamsRebalanceData.get(), this, taskManager)
+            );
+        }
+        this.streamsMetadataState = streamsMetadataState;
     }
 
     private static final class InternalConsumerConfig extends ConsumerConfig {
@@ -961,6 +1139,8 @@ public class StreamThread extends Thread implements 
ProcessingThread {
         final long startMs = time.milliseconds();
         now = startMs;
 
+        maybeHandleAssignmentFromStreamsRebalanceProtocol();
+
         final long pollLatency;
         taskManager.resumePollingForPartitionsWithAvailableSpace();
         pollLatency = pollPhase();
@@ -1296,6 +1476,39 @@ public class StreamThread extends Thread implements 
ProcessingThread {
         return records;
     }
 
+    public void maybeHandleAssignmentFromStreamsRebalanceProtocol() {
+        if (streamsRebalanceData.isPresent()) {
+
+            if (streamsRebalanceData.get().shutdownRequested()) {
+                
assignmentErrorCode.set(AssignorError.SHUTDOWN_REQUESTED.code());
+            }
+
+            // TODO: process IQ-related metadata
+
+            // Process assignment from Streams Rebalance Protocol
+            streamsRebalanceEventsProcessor.get().process();
+        }
+    }
+
+    static Map<TopicPartition, PartitionInfo> getTopicPartitionInfo(final 
Map<HostInfo, Set<TopicPartition>> partitionsByHost) {
+        final Map<TopicPartition, PartitionInfo> topicToPartitionInfo = new 
HashMap<>();
+        for (final Set<TopicPartition> value : partitionsByHost.values()) {
+            for (final TopicPartition topicPartition : value) {
+                topicToPartitionInfo.put(
+                    topicPartition,
+                    new PartitionInfo(
+                        topicPartition.topic(),
+                        topicPartition.partition(),
+                        null,
+                        new Node[0],
+                        new Node[0]
+                    )
+                );
+            }
+        }
+        return topicToPartitionInfo;
+    }
+
     private void resetOffsets(final Set<TopicPartition> partitions, final 
Exception cause) {
         final Set<String> loggedTopics = new HashSet<>();
         final Set<TopicPartition> seekToBeginning = new HashSet<>();
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/StreamsConfigTest.java 
b/streams/src/test/java/org/apache/kafka/streams/StreamsConfigTest.java
index 7e9447851e1..a7f657ec54b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/StreamsConfigTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/StreamsConfigTest.java
@@ -71,6 +71,7 @@ import static 
org.apache.kafka.streams.StreamsConfig.DSL_STORE_SUPPLIERS_CLASS_C
 import static 
org.apache.kafka.streams.StreamsConfig.ENABLE_METRICS_PUSH_CONFIG;
 import static 
org.apache.kafka.streams.StreamsConfig.ENSURE_EXPLICIT_INTERNAL_RESOURCE_NAMING_CONFIG;
 import static org.apache.kafka.streams.StreamsConfig.EXACTLY_ONCE_V2;
+import static org.apache.kafka.streams.StreamsConfig.GROUP_PROTOCOL_CONFIG;
 import static 
org.apache.kafka.streams.StreamsConfig.MAX_RACK_AWARE_ASSIGNMENT_TAG_KEY_LENGTH;
 import static 
org.apache.kafka.streams.StreamsConfig.MAX_RACK_AWARE_ASSIGNMENT_TAG_VALUE_LENGTH;
 import static 
org.apache.kafka.streams.StreamsConfig.PROCESSOR_WRAPPER_CLASS_CONFIG;
@@ -1596,6 +1597,25 @@ public class StreamsConfigTest {
         
assertTrue(streamsConfig.getBoolean(ENSURE_EXPLICIT_INTERNAL_RESOURCE_NAMING_CONFIG));
     }
 
+    @Test
+    public void shouldSetGroupProtocolToClassicByDefault() {
+        
assertTrue(GroupProtocol.CLASSIC.name().equalsIgnoreCase(streamsConfig.getString(GROUP_PROTOCOL_CONFIG)));
+    }
+
+    @Test
+    public void shouldSetGroupProtocolToClassic() {
+        props.put(GROUP_PROTOCOL_CONFIG, GroupProtocol.CLASSIC.name());
+        streamsConfig = new StreamsConfig(props);
+        
assertTrue(GroupProtocol.CLASSIC.name().equalsIgnoreCase(streamsConfig.getString(GROUP_PROTOCOL_CONFIG)));
+    }
+
+    @Test
+    public void shouldSetGroupProtocolToStreams() {
+        props.put(GROUP_PROTOCOL_CONFIG, GroupProtocol.STREAMS.name());
+        streamsConfig = new StreamsConfig(props);
+        
assertTrue(GroupProtocol.STREAMS.name().equalsIgnoreCase(streamsConfig.getString(GROUP_PROTOCOL_CONFIG)));
+    }
+
     static class MisconfiguredSerde implements Serde<Object> {
         @Override
         public void configure(final Map<String, ?>  configs, final boolean 
isKey) {
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 5d32efe8329..fa72b116f00 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -1447,6 +1447,9 @@ public class StreamThreadTest {
             new LinkedList<>(),
             null,
             HANDLER,
+            null,
+            Optional.empty(),
+            Optional.empty(),
             null
         ).updateThreadMetadata(adminClientId(CLIENT_ID));
 
@@ -2666,6 +2669,9 @@ public class StreamThreadTest {
             new LinkedList<>(),
             null,
             HANDLER,
+            null,
+            Optional.empty(),
+            Optional.empty(),
             null
         ) {
             @Override
@@ -2725,6 +2731,9 @@ public class StreamThreadTest {
             new LinkedList<>(),
             null,
             HANDLER,
+            null,
+            Optional.empty(),
+            Optional.empty(),
             null
         ) {
             @Override
@@ -2793,6 +2802,9 @@ public class StreamThreadTest {
             new LinkedList<>(),
             null,
             HANDLER,
+            null,
+            Optional.empty(),
+            Optional.empty(),
             null
         ) {
             @Override
@@ -2857,6 +2869,9 @@ public class StreamThreadTest {
             new LinkedList<>(),
             null,
             HANDLER,
+            null,
+            Optional.empty(),
+            Optional.empty(),
             null
         ) {
             @Override
@@ -2918,6 +2933,9 @@ public class StreamThreadTest {
             new LinkedList<>(),
             null,
             HANDLER,
+            null,
+            Optional.empty(),
+            Optional.empty(),
             null
         ) {
             @Override
@@ -3150,6 +3168,9 @@ public class StreamThreadTest {
             new LinkedList<>(),
             null,
             HANDLER,
+            null,
+            Optional.empty(),
+            Optional.empty(),
             null
         );
         final MetricName testMetricName = new MetricName("test_metric", "", 
"", new HashMap<>());
@@ -3207,6 +3228,9 @@ public class StreamThreadTest {
             new LinkedList<>(),
             null,
             (e, b) -> { },
+            null,
+            Optional.empty(),
+            Optional.empty(),
             null
         ) {
             @Override
@@ -3587,6 +3611,9 @@ public class StreamThreadTest {
             new LinkedList<>(),
             null,
             null,
+            null,
+            Optional.empty(),
+            Optional.empty(),
             null
         );
     }
@@ -3709,6 +3736,9 @@ public class StreamThreadTest {
             new LinkedList<>(),
             null,
             HANDLER,
+            null,
+            Optional.empty(),
+            Optional.empty(),
             null
         );
     }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java 
b/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java
index bf53c76921b..5bc679245d7 100644
--- a/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java
+++ b/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java
@@ -129,7 +129,9 @@ public class SmokeTestClient extends SmokeTestUtil {
 
     private Properties getStreamsConfig(final Properties props) {
         final Properties fullProps = new Properties(props);
-        fullProps.put(StreamsConfig.APPLICATION_ID_CONFIG, "SmokeTest");
+        if (!props.containsKey(StreamsConfig.APPLICATION_ID_CONFIG)) {
+            fullProps.put(StreamsConfig.APPLICATION_ID_CONFIG, "SmokeTest");
+        }
         fullProps.put(StreamsConfig.CLIENT_ID_CONFIG, "SmokeTest-" + name);
         fullProps.put(StreamsConfig.STATE_DIR_CONFIG, 
TestUtils.tempDirectory().getAbsolutePath());
         fullProps.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, 
StreamsConfig.EXACTLY_ONCE_V2);

Reply via email to