dajac commented on code in PR #14640:
URL: https://github.com/apache/kafka/pull/14640#discussion_r1418677356


##########
clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java:
##########
@@ -1497,6 +1606,66 @@ private void subscribeInternal(Collection<String> topics, Optional<ConsumerRebal
         }
     }
 
+    /**
+     * This method can be used in cases where the caller has an event that needs to both block for completion and
+     * process background events. For some events, in order to fully process the associated logic, the
+     * {@link ConsumerNetworkThread background thread} needs assistance from the application thread to complete.
+     * If the application thread simply blocked on the event after submitting it, the processing would deadlock.
+     * The logic herein is basically a loop that performs two tasks in each iteration:
+     *
+     * <ol>
+     *     <li>Process background events, if any</li>
+     *     <li><em>Briefly</em> wait for {@link CompletableApplicationEvent an event} to complete</li>
+     * </ol>
+     *
+     * <p/>
+     *
+     * Each iteration gives the application thread an opportunity to process background events, which may be
+     * necessary to complete the overall processing.
+     *
+     * <p/>
+     *
+     * As an example, take {@link #unsubscribe()}. To start unsubscribing, the application thread enqueues an
+     * {@link UnsubscribeApplicationEvent} on the application event queue. That event will eventually trigger the
+     * rebalancing logic in the background thread. Critically, as part of this rebalancing work, the
+     * {@link ConsumerRebalanceListener#onPartitionsRevoked(Collection)} callback needs to be invoked. However,
+     * this callback must be executed on the application thread. To achieve this, the background thread enqueues a
+     * {@link ConsumerRebalanceListenerCallbackNeededEvent} on its background event queue. That event queue is
+     * periodically queried by the application thread to see if there's work to be done. When the application thread
+     * sees a {@link ConsumerRebalanceListenerCallbackNeededEvent}, it processes it and then enqueues a
+     * {@link ConsumerRebalanceListenerCallbackCompletedEvent} on the application event queue. Moments later, the
+     * background thread will see that event, process it, and continue execution of the rebalancing logic. The
+     * rebalancing logic cannot complete until the {@link ConsumerRebalanceListener} callback is performed.
+     *
+     * @param event Event that contains a {@link CompletableFuture}; it is on this future that the application thread
+     *              will wait for completion
+     * @param timer Overall timer that bounds how long the application thread will wait for the event to complete
+     * @return {@code true} if the event completed within the timeout, {@code false} otherwise
+     */
+    private boolean processBackgroundEvents(CompletableApplicationEvent<?> event, Timer timer) {
+        log.trace("Enqueuing event {} for processing; will wait up to {} ms to complete", event, timer.remainingMs());
+
+        do {
+            backgroundEventProcessor.process();
+
+            try {
+                Timer pollInterval = time.timer(100L);

Review Comment:
   I wonder if we should instead wait only when there are no events in the queue
waiting to be processed. In the worst case, the call to `process()` could have
just missed the event that we need, so we would wait 100ms for nothing. Have you
considered this?
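One way to act on this suggestion, sketched under simplifying assumptions: background events are modeled as `Runnable`s on a `java.util.concurrent.BlockingQueue`, standing in for the real `BackgroundEvent` queue and `backgroundEventProcessor`. Rather than waiting on the future in fixed 100 ms slices, this variant drains whatever is queued and then blocks on the queue itself; a no-op `WAKE_UP` entry is offered when the future completes, so whichever signal arrives first ends the wait. This swaps in a "block on the queue" technique rather than tuning the poll interval, and `awaitWhileProcessing` and `WAKE_UP` are hypothetical names, not part of the PR.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

public class DrainThenWaitSketch {

    // Hypothetical no-op sentinel offered to the queue when the awaited future completes,
    // so a thread blocked on the queue wakes up for either kind of signal.
    static final Runnable WAKE_UP = () -> { };

    /**
     * Hypothetical stand-in for processBackgroundEvents(event, timer): runs queued background
     * events on the calling thread until {@code future} completes or {@code timeoutMs} elapses.
     * It blocks only when the queue is empty.
     *
     * @return true if the future completed before the deadline, false otherwise
     */
    static boolean awaitWhileProcessing(BlockingQueue<Runnable> backgroundEvents,
                                        CompletableFuture<?> future,
                                        long timeoutMs) {
        long deadlineNs = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeoutMs);
        future.whenComplete((result, error) -> backgroundEvents.offer(WAKE_UP));

        while (true) {
            // 1. Process every event that is already queued, without blocking.
            Runnable event;
            while ((event = backgroundEvents.poll()) != null)
                event.run();

            if (future.isDone())
                return true;

            long remainingNs = deadlineNs - System.nanoTime();
            if (remainingNs <= 0)
                return false;

            // 2. Nothing is queued: block on the queue until a new event or WAKE_UP arrives.
            //    An event that lands between the drain above and this call is returned immediately.
            try {
                event = backgroundEvents.poll(remainingNs, TimeUnit.NANOSECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return future.isDone();
            }
            if (event != null)
                event.run();
        }
    }

    public static void main(String[] args) {
        BlockingQueue<Runnable> queue = new LinkedBlockingQueue<>();
        CompletableFuture<Void> callbackDone = new CompletableFuture<>();
        // Simulates the background thread asking the application thread to run a callback.
        queue.add(() -> callbackDone.complete(null));
        System.out.println(awaitWhileProcessing(queue, callbackDone, 1_000)); // prints "true"
    }
}
```

A trade-off of this design is that a `WAKE_UP` entry can be left behind in the queue after a timeout, so the real event processor would need to tolerate (or filter out) such sentinel events.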



##########
clients/src/main/java/org/apache/kafka/clients/consumer/internals/MembershipManagerImpl.java:
##########
@@ -971,12 +989,59 @@ private CompletableFuture<Void> invokeOnPartitionsLostCallback(Set<TopicPartitio
         // behaviour.
         Optional<ConsumerRebalanceListener> listener = subscriptions.rebalanceListener();
         if (!partitionsLost.isEmpty() && listener.isPresent()) {
-            throw new UnsupportedOperationException("User-defined callbacks not supported yet");
+            return enqueueConsumerRebalanceListenerCallback(ON_PARTITIONS_LOST, partitionsLost);
         } else {
             return CompletableFuture.completedFuture(null);
         }
     }
 
+    /**
+     * Enqueue a {@link ConsumerRebalanceListenerCallbackNeededEvent} to trigger the execution of the
+     * appropriate {@link ConsumerRebalanceListener} {@link ConsumerRebalanceListenerMethodName method} on the
+     * application thread.
+     *
+     * <p/>
+     *
+     * Because the reconciliation process (run in the background thread) is blocked until the application thread
+     * completes the callback, we need to provide a {@link CompletableFuture} by which to remember where we left off.
+     *
+     * @param methodName Callback method that needs to be executed on the application thread
+     * @param partitions Partitions to supply to the callback method
+     * @return Future that will be chained within the rest of the reconciliation logic
+     */
+    private CompletableFuture<Void> enqueueConsumerRebalanceListenerCallback(ConsumerRebalanceListenerMethodName methodName,
+                                                                             Set<TopicPartition> partitions) {
+        SortedSet<TopicPartition> sortedPartitions = new TreeSet<>(TOPIC_PARTITION_COMPARATOR);
+        sortedPartitions.addAll(partitions);
+        CompletableBackgroundEvent<Void> event = new ConsumerRebalanceListenerCallbackNeededEvent(methodName, sortedPartitions);
+        backgroundEventHandler.add(event);
+        log.debug("The event to trigger the {} method execution was enqueued successfully", methodName.fullyQualifiedMethodName());
+        return event.future();
+    }
+
+    @Override
+    public void consumerRebalanceListenerCallbackCompleted(ConsumerRebalanceListenerCallbackCompletedEvent event) {
+        ConsumerRebalanceListenerMethodName methodName = event.methodName();
+        Optional<KafkaException> error = event.error();
+        CompletableFuture<Void> future = event.future();
+
+        if (error.isPresent()) {
+            String message = error.get().getMessage();
+            log.warn(
+                "The {} method completed with an error ({}); signaling to continue to the next phase of rebalance",
+                methodName.fullyQualifiedMethodName(),
+                message
+            );
+        } else {
+            log.debug(
+                "The {} method completed successfully; signaling to continue to the next phase of rebalance",
+                methodName.fullyQualifiedMethodName()
+            );
+        }
+
+        future.complete(null);

Review Comment:
   It looks like we don't propagate the `error`. I am not sure about this. My 
understanding is that, in the legacy consumer, an exception thrown by the 
callback would stop the rebalance. I wonder if we should do the same here in 
order to interrupt and retry the reconciliation. @lianetm If we propagate the 
exception here, is the current logic going to retry? I suppose so.
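A minimal sketch of what propagating the error could look like, assuming the reconciliation steps are chained on the callback future with standard `CompletableFuture` composition. Completing the future exceptionally skips the dependent `thenApply` stage and routes the failure to an `exceptionally` handler, which is where an abort-and-retry decision could live. Whether the existing reconciliation logic actually retries in that case is the open question above; `onCallbackCompleted` and `reconcile` are hypothetical names, and `RuntimeException` stands in for `KafkaException` to keep the block self-contained.

```java
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;

public class CallbackErrorPropagationSketch {

    // Hypothetical variant of consumerRebalanceListenerCallbackCompleted(): instead of always
    // calling complete(null), it fails the future when the user callback threw.
    static void onCallbackCompleted(CompletableFuture<Void> future, Optional<RuntimeException> error) {
        if (error.isPresent())
            future.completeExceptionally(error.get());
        else
            future.complete(null);
    }

    // Hypothetical reconciliation chain: the next phase only runs if the callback succeeded;
    // otherwise the failure reaches a handler that could abort (and later retry) the reconciliation.
    static CompletableFuture<String> reconcile(CompletableFuture<Void> callbackFuture) {
        return callbackFuture
            .thenApply(ignored -> "next phase")
            .exceptionally(e -> {
                // Dependent stages see the original error wrapped in a CompletionException.
                Throwable cause = e instanceof CompletionException && e.getCause() != null ? e.getCause() : e;
                return "aborted: " + cause.getMessage();
            });
    }

    public static void main(String[] args) {
        CompletableFuture<Void> ok = new CompletableFuture<>();
        CompletableFuture<String> okResult = reconcile(ok);
        onCallbackCompleted(ok, Optional.empty());
        System.out.println(okResult.join()); // prints "next phase"

        CompletableFuture<Void> failed = new CompletableFuture<>();
        CompletableFuture<String> failedResult = reconcile(failed);
        onCallbackCompleted(failed, Optional.of(new RuntimeException("listener threw")));
        System.out.println(failedResult.join()); // prints "aborted: listener threw"
    }
}
```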



##########
clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessor.java:
##########
@@ -180,12 +179,12 @@ private void process(final ListOffsetsApplicationEvent event) {
      * consumer join the group if it is not part of it yet, or send the updated subscription if
      * it is already a member.
      */
-    private void process(final SubscriptionChangeApplicationEvent event) {
-        if (!requestManagers.membershipManager.isPresent()) {
-            throw new RuntimeException("Group membership manager not present when processing a " +
-                    "subscribe event");
+    private void process(final SubscriptionChangeApplicationEvent ignored) {
+        if (!requestManagers.heartbeatRequestManager.isPresent()) {

Review Comment:
   Why did we change this? `requestManagers` already has a reference to the
membership manager, so it seems simpler to just get it from there. This would also
allow us to remove the `membershipManager` method in the heartbeat manager.
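For reference, a sketch of the lookup the comment suggests, assuming `RequestManagers` exposes an `Optional` `membershipManager` field the way the removed lines above read it. The `MembershipManager` and `RequestManagers` types here are trimmed-down stand-ins, and `onUnsubscribe()` is a hypothetical method name.

```java
import java.util.Optional;

public class RequestManagersLookupSketch {

    // Hypothetical, trimmed-down stand-ins for the real MembershipManager and RequestManagers.
    interface MembershipManager {
        void onUnsubscribe(); // hypothetical method name
    }

    static final class RequestManagers {
        final Optional<MembershipManager> membershipManager;

        RequestManagers(Optional<MembershipManager> membershipManager) {
            this.membershipManager = membershipManager;
        }
    }

    // Reads the membership manager straight from RequestManagers instead of going
    // through the heartbeat request manager.
    static void processUnsubscribe(RequestManagers requestManagers) {
        MembershipManager membershipManager = requestManagers.membershipManager.orElseThrow(
            () -> new RuntimeException("Group membership manager not present when processing an unsubscribe event"));
        membershipManager.onUnsubscribe();
    }

    public static void main(String[] args) {
        int[] calls = {0};
        processUnsubscribe(new RequestManagers(Optional.of(() -> calls[0]++)));
        System.out.println(calls[0]); // prints "1"
    }
}
```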



##########
clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessor.java:
##########
@@ -198,11 +197,12 @@ private void process(final SubscriptionChangeApplicationEvent event) {
      *              the group is sent out.
      */
     private void process(final UnsubscribeApplicationEvent event) {
-        if (!requestManagers.membershipManager.isPresent()) {
-            throw new RuntimeException("Group membership manager not present when processing an " +
-                    "unsubscribe event");
+        if (!requestManagers.heartbeatRequestManager.isPresent()) {

Review Comment:
   Same question here.



##########
clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ConsumerRebalanceListenerCallbackCompletedEvent.java:
##########
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals.events;
+
+import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
+import org.apache.kafka.clients.consumer.internals.ConsumerRebalanceListenerMethodName;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.KafkaException;
+
+import java.util.Collections;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.SortedSet;
+import java.util.concurrent.CompletableFuture;
+
+/**
+ * Event that signifies that the application thread has executed the {@link ConsumerRebalanceListener} callback. If
+ * the callback execution threw an error, it is included in the event should any event listener want to know.
+ */
+public class ConsumerRebalanceListenerCallbackCompletedEvent extends ApplicationEvent {
+
+    private final ConsumerRebalanceListenerMethodName methodName;
+    private final SortedSet<TopicPartition> partitions;
+    private final CompletableFuture<Void> future;
+    private final Optional<KafkaException> error;
+
+    public ConsumerRebalanceListenerCallbackCompletedEvent(ConsumerRebalanceListenerMethodName methodName,
+                                                           SortedSet<TopicPartition> partitions,

Review Comment:
   We pass the method name and the partitions here but we don't use them except 
in `toString()`. Do we need them?
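If only `partitions` turns out to be unnecessary, a trimmed event could look like the sketch below. Note that `consumerRebalanceListenerCallbackCompleted` in `MembershipManagerImpl` (earlier in this review) reads `event.methodName()` for its log messages, so this sketch keeps the method name and drops the partitions. All types are hypothetical stand-ins, and `RuntimeException` replaces `KafkaException` to keep the block self-contained.

```java
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

public class SlimCallbackCompletedEventSketch {

    // Stand-in mirroring the ON_PARTITIONS_* constants of ConsumerRebalanceListenerMethodName.
    enum MethodName { ON_PARTITIONS_REVOKED, ON_PARTITIONS_ASSIGNED, ON_PARTITIONS_LOST }

    // Hypothetical trimmed event: keeps the fields the receiving side reads and drops `partitions`.
    static final class CallbackCompletedEvent {
        private final MethodName methodName;            // read for logging when the callback completes
        private final CompletableFuture<Void> future;   // completed to resume reconciliation
        private final Optional<RuntimeException> error; // error thrown by the user callback, if any

        CallbackCompletedEvent(MethodName methodName,
                               CompletableFuture<Void> future,
                               Optional<RuntimeException> error) {
            this.methodName = Objects.requireNonNull(methodName);
            this.future = Objects.requireNonNull(future);
            this.error = Objects.requireNonNull(error);
        }

        MethodName methodName() { return methodName; }

        CompletableFuture<Void> future() { return future; }

        Optional<RuntimeException> error() { return error; }

        @Override
        public String toString() {
            return "CallbackCompletedEvent{methodName=" + methodName + ", error=" + error + "}";
        }
    }

    public static void main(String[] args) {
        CallbackCompletedEvent event = new CallbackCompletedEvent(
            MethodName.ON_PARTITIONS_LOST, new CompletableFuture<>(), Optional.empty());
        System.out.println(event); // prints "CallbackCompletedEvent{methodName=ON_PARTITIONS_LOST, error=Optional.empty}"
    }
}
```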



##########
clients/src/test/java/org/apache/kafka/clients/consumer/internals/MembershipManagerImplTest.java:
##########
@@ -790,6 +812,197 @@ public void testOnSubscriptionUpdatedTransitionsToJoiningOnlyIfNotInGroup() {
         verify(membershipManager, never()).transitionToJoining();
     }
 
+    @Test
+    public void testListenerCallbacksBasic() {
+        // Step 1: set up mocks
+        MembershipManagerImpl membershipManager = createMemberInStableState();
+        mockOwnedPartition("topic1", 0);
+        CounterConsumerRebalanceListener listener = new CounterConsumerRebalanceListener();
+        ConsumerRebalanceListenerInvoker invoker = consumerRebalanceListenerInvoker();
+
+        when(subscriptionState.rebalanceListener()).thenReturn(Optional.of(listener));
+        doNothing().when(subscriptionState).markPendingRevocation(anySet());
+
+        // Step 2: put the state machine into the appropriate... state
+        receiveEmptyAssignment(membershipManager);
+        assertEquals(MemberState.RECONCILING, membershipManager.state());
+        assertEquals(Collections.emptySet(), membershipManager.currentAssignment());
+        assertTrue(membershipManager.reconciliationInProgress());
+        listener.assertCounts(0, 0, 0);
+
+        // Step 2: revoke partitions
+        performCallback(
+                membershipManager,
+                invoker,
+                ConsumerRebalanceListenerMethodName.ON_PARTITIONS_REVOKED,
+                topicPartitions(new TopicPartition("topic1", 0))
+        );
+
+        // Step 3: assign partitions
+        performCallback(
+                membershipManager,
+                invoker,
+                ConsumerRebalanceListenerMethodName.ON_PARTITIONS_ASSIGNED,
+                Collections.emptySortedSet()
+        );
+
+        // Step 4: Receive ack and make sure we're done and our listener was called appropriately
+        membershipManager.onHeartbeatRequestSent();
+        assertEquals(MemberState.STABLE, membershipManager.state());
+
+        listener.assertCounts(1, 1, 0);
+    }
+
+    @Test
+    public void testListenerCallbacksNoListeners() {
+        // Step 1: set up mocks
+        MembershipManagerImpl membershipManager = createMemberInStableState();
+        mockOwnedPartition("topic1", 0);
+        CounterConsumerRebalanceListener listener = new CounterConsumerRebalanceListener();
+
+        when(subscriptionState.rebalanceListener()).thenReturn(Optional.empty());
+        doNothing().when(subscriptionState).markPendingRevocation(anySet());
+
+        // Step 2: put the state machine into the appropriate... state
+        receiveEmptyAssignment(membershipManager);
+        assertEquals(MemberState.ACKNOWLEDGING, membershipManager.state());
+        assertEquals(Collections.emptySet(), membershipManager.currentAssignment());
+        assertFalse(membershipManager.reconciliationInProgress());
+        assertEquals(0, backgroundEventQueue.size());
+        listener.assertCounts(0, 0, 0);
+
+        // Step 3: Receive ack and make sure we're done and our listener was called appropriately
+        membershipManager.onHeartbeatRequestSent();
+        assertEquals(MemberState.STABLE, membershipManager.state());
+
+        listener.assertCounts(0, 0, 0);
+    }
+
+    @Test
+    public void testOnPartitionsLostNoError() {
+        mockOwnedPartition("topic1", 0);
+        testOnPartitionsLost(Optional.empty());
+    }
+
+    @Test
+    public void testOnPartitionsLostError() {
+        mockOwnedPartition("topic1", 0);
+        testOnPartitionsLost(Optional.of(new KafkaException("Intentional error for test")));
+    }
+
+    private void testOnPartitionsLost(Optional<RuntimeException> lostError) {
+        // Step 1: set up mocks
+        MembershipManagerImpl membershipManager = createMemberInStableState();
+        CounterConsumerRebalanceListener listener = new CounterConsumerRebalanceListener(
+                Optional.empty(),
+                Optional.empty(),
+                lostError
+        );
+        ConsumerRebalanceListenerInvoker invoker = consumerRebalanceListenerInvoker();
+
+        when(subscriptionState.rebalanceListener()).thenReturn(Optional.of(listener));
+        doNothing().when(subscriptionState).markPendingRevocation(anySet());
+
+        // Step 2: put the state machine into the appropriate... state
+        membershipManager.transitionToFenced();
+        assertEquals(MemberState.FENCED, membershipManager.state());
+        assertEquals(Collections.emptySet(), membershipManager.currentAssignment());
+        listener.assertCounts(0, 0, 0);
+
+        // Step 3: invoke the callback
+        performCallback(
+                membershipManager,
+                invoker,
+                ConsumerRebalanceListenerMethodName.ON_PARTITIONS_LOST,
+                topicPartitions(new TopicPartition("topic1", 0))
+        );
+
+        // Step 4: Receive ack and make sure we're done and our listener was called appropriately
+        membershipManager.onHeartbeatRequestSent();
+        assertEquals(MemberState.JOINING, membershipManager.state());
+
+        listener.assertCounts(0, 0, 1);
+    }
+
+    private ConsumerRebalanceListenerInvoker consumerRebalanceListenerInvoker() {
+        ConsumerCoordinatorMetrics coordinatorMetrics = new ConsumerCoordinatorMetrics(
+                subscriptionState,
+                new Metrics(),
+                "test-");
+        return new ConsumerRebalanceListenerInvoker(
+                new LogContext(),
+                subscriptionState,
+                new MockTime(1),
+                coordinatorMetrics
+        );
+    }
+
+    private SortedSet<TopicPartition> topicPartitions(TopicPartition... topicPartitions) {
+        SortedSet<TopicPartition> revokedPartitions = new TreeSet<>(new Utils.TopicPartitionComparator());
+        revokedPartitions.addAll(Arrays.asList(topicPartitions));
+        return revokedPartitions;
+    }
+
+    private void performCallback(MembershipManagerImpl membershipManager,
+                                 ConsumerRebalanceListenerInvoker invoker,
+                                 ConsumerRebalanceListenerMethodName methodName,
+                                 SortedSet<TopicPartition> partitions) {
+        // Set up our mock application event handler & background event processor.
+        ApplicationEventHandler applicationEventHandler = mock(ApplicationEventHandler.class);
+
+        doAnswer(a -> {
+            ConsumerRebalanceListenerCallbackCompletedEvent completedEvent = a.getArgument(0);
+            membershipManager.consumerRebalanceListenerCallbackCompleted(completedEvent);
+            return null;
+        }).when(applicationEventHandler).add(any(ConsumerRebalanceListenerCallbackCompletedEvent.class));
+
+        // We expect only our enqueued event in the background queue.
+        assertEquals(1, backgroundEventQueue.size());
+        assertNotNull(backgroundEventQueue.peek());
+        assertInstanceOf(ConsumerRebalanceListenerCallbackNeededEvent.class, backgroundEventQueue.peek());
+        ConsumerRebalanceListenerCallbackNeededEvent neededEvent = (ConsumerRebalanceListenerCallbackNeededEvent) backgroundEventQueue.poll();
+        assertNotNull(neededEvent);
+        assertEquals(methodName, neededEvent.methodName());
+        assertEquals(partitions, neededEvent.partitions());
+
+        final Exception e;
+
+        switch (methodName) {
+            case ON_PARTITIONS_REVOKED:
+                e = invoker.invokePartitionsRevoked(partitions);
+                break;
+
+            case ON_PARTITIONS_ASSIGNED:
+                e = invoker.invokePartitionsAssigned(partitions);
+                break;
+
+            case ON_PARTITIONS_LOST:
+                e = invoker.invokePartitionsLost(partitions);
+                break;
+
+            default:
+                throw new IllegalArgumentException("The method " + methodName + " to invoke was not expected");
+        }
+
+        final Optional<KafkaException> error;
+
+        if (e != null) {
+            if (e instanceof KafkaException)
+                error = Optional.of((KafkaException) e);
+            else
+                error = Optional.of(new KafkaException("User rebalance callback throws an error", e));
+        } else {
+            error = Optional.empty();
+        }

Review Comment:
   Is this actually needed? We don't really care about the `invoker` in this 
context. We only care about validating the received event and pushing the 
completed event.



##########
clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java:
##########
@@ -42,7 +42,11 @@
 import org.apache.kafka.clients.consumer.internals.events.ApplicationEventProcessor;

Review Comment:
   Do we have any new tests to cover the changes in this class?



##########
core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala:
##########
@@ -169,7 +169,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
       startingTimestamp = startingTimestamp)

Review Comment:
   So we cannot enable any of those tests because we are still missing other
features. Is my understanding correct?



##########
clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessor.java:
##########
@@ -113,16 +113,15 @@ public void process(ApplicationEvent event) {
                 process((UnsubscribeApplicationEvent) event);
                 return;
 
+            case CONSUMER_REBALANCE_LISTENER_CALLBACK_COMPLETED:

Review Comment:
   Do we have tests to cover changes made in this class? I see that 
ApplicationEventProcessorTest does not even exist. I think that it would be 
good to have unit tests for it.



##########
clients/src/test/java/org/apache/kafka/clients/consumer/internals/MembershipManagerImplTest.java:
##########
@@ -790,6 +812,197 @@ public void testOnSubscriptionUpdatedTransitionsToJoiningOnlyIfNotInGroup() {
         verify(membershipManager, never()).transitionToJoining();
     }
 
+    @Test
+    public void testListenerCallbacksBasic() {

Review Comment:
   Should we add the following tests?
   * Calling on partitions assigned with partitions.
   * Receiving an error when on revoked is called.
   * Receiving an error when on assigned is called.



##########
clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessor.java:
##########
@@ -223,6 +223,18 @@ private void process(final TopicMetadataApplicationEvent event) {
         event.chain(future);
     }
 
+    private void process(final ConsumerRebalanceListenerCallbackCompletedEvent event) {
+        if (!requestManagers.heartbeatRequestManager.isPresent()) {

Review Comment:
   And here.



##########
clients/src/test/java/org/apache/kafka/clients/consumer/internals/MembershipManagerImplTest.java:
##########
@@ -790,6 +812,197 @@ public void testOnSubscriptionUpdatedTransitionsToJoiningOnlyIfNotInGroup() {
         verify(membershipManager, never()).transitionToJoining();
     }
 
+    @Test
+    public void testListenerCallbacksNoListeners() {
+        // Step 1: set up mocks
+        MembershipManagerImpl membershipManager = createMemberInStableState();
+        mockOwnedPartition("topic1", 0);
+        CounterConsumerRebalanceListener listener = new CounterConsumerRebalanceListener();

Review Comment:
   This test is weird, no? `listener` is not used anywhere so of course it will 
always have zero values internally.



##########
clients/src/test/java/org/apache/kafka/clients/consumer/internals/MembershipManagerImplTest.java:
##########
@@ -790,6 +812,197 @@ public void testOnSubscriptionUpdatedTransitionsToJoiningOnlyIfNotInGroup() {
         verify(membershipManager, never()).transitionToJoining();
     }
 
+    private void performCallback(MembershipManagerImpl membershipManager,
+                                 ConsumerRebalanceListenerInvoker invoker,
+                                 ConsumerRebalanceListenerMethodName methodName,
+                                 SortedSet<TopicPartition> partitions) {

Review Comment:
   Should we prefix those two with `expected`?



##########
clients/src/test/java/org/apache/kafka/clients/consumer/internals/MembershipManagerImplTest.java:
##########
@@ -790,6 +812,197 @@ public void testOnSubscriptionUpdatedTransitionsToJoiningOnlyIfNotInGroup() {
         verify(membershipManager, never()).transitionToJoining();
     }
 
+    @Test
+    public void testListenerCallbacksBasic() {
+        // Step 1: set up mocks
+        MembershipManagerImpl membershipManager = createMemberInStableState();
+        mockOwnedPartition("topic1", 0);
+        CounterConsumerRebalanceListener listener = new 
CounterConsumerRebalanceListener();
+        ConsumerRebalanceListenerInvoker invoker = 
consumerRebalanceListenerInvoker();
+
+        
when(subscriptionState.rebalanceListener()).thenReturn(Optional.of(listener));
+        doNothing().when(subscriptionState).markPendingRevocation(anySet());
+
+        // Step 2: put the state machine into the appropriate... state
+        receiveEmptyAssignment(membershipManager);
+        assertEquals(MemberState.RECONCILING, membershipManager.state());
+        assertEquals(Collections.emptySet(), membershipManager.currentAssignment());
+        assertTrue(membershipManager.reconciliationInProgress());
+        listener.assertCounts(0, 0, 0);
+
+        // Step 3: revoke partitions
+        performCallback(
+                membershipManager,
+                invoker,
+                ConsumerRebalanceListenerMethodName.ON_PARTITIONS_REVOKED,
+                topicPartitions(new TopicPartition("topic1", 0))
+        );
+
+        // Step 4: assign partitions
+        performCallback(
+                membershipManager,
+                invoker,
+                ConsumerRebalanceListenerMethodName.ON_PARTITIONS_ASSIGNED,
+                Collections.emptySortedSet()
+        );
+
+        // Step 5: Receive ack and make sure we're done and our listener was called appropriately
+        membershipManager.onHeartbeatRequestSent();
+        assertEquals(MemberState.STABLE, membershipManager.state());
+
+        listener.assertCounts(1, 1, 0);
+    }
+
+    @Test
+    public void testListenerCallbacksNoListeners() {
+        // Step 1: set up mocks
+        MembershipManagerImpl membershipManager = createMemberInStableState();
+        mockOwnedPartition("topic1", 0);
+        CounterConsumerRebalanceListener listener = new CounterConsumerRebalanceListener();
+
+        when(subscriptionState.rebalanceListener()).thenReturn(Optional.empty());
+        doNothing().when(subscriptionState).markPendingRevocation(anySet());
+
+        // Step 2: put the state machine into the appropriate... state
+        receiveEmptyAssignment(membershipManager);
+        assertEquals(MemberState.ACKNOWLEDGING, membershipManager.state());
+        assertEquals(Collections.emptySet(), membershipManager.currentAssignment());
+        assertFalse(membershipManager.reconciliationInProgress());
+        assertEquals(0, backgroundEventQueue.size());
+        listener.assertCounts(0, 0, 0);
+
+        // Step 3: Receive ack and make sure we're done and our listener was called appropriately
+        membershipManager.onHeartbeatRequestSent();
+        assertEquals(MemberState.STABLE, membershipManager.state());
+
+        listener.assertCounts(0, 0, 0);
+    }
+
+    @Test
+    public void testOnPartitionsLostNoError() {
+        mockOwnedPartition("topic1", 0);
+        testOnPartitionsLost(Optional.empty());
+    }
+
+    @Test
+    public void testOnPartitionsLostError() {
+        mockOwnedPartition("topic1", 0);
+        testOnPartitionsLost(Optional.of(new KafkaException("Intentional error for test")));
+    }
+
+    private void testOnPartitionsLost(Optional<RuntimeException> lostError) {
+        // Step 1: set up mocks
+        MembershipManagerImpl membershipManager = createMemberInStableState();
+        CounterConsumerRebalanceListener listener = new CounterConsumerRebalanceListener(
+                Optional.empty(),
+                Optional.empty(),
+                lostError
+        );
+        ConsumerRebalanceListenerInvoker invoker = consumerRebalanceListenerInvoker();
+
+        when(subscriptionState.rebalanceListener()).thenReturn(Optional.of(listener));
+        doNothing().when(subscriptionState).markPendingRevocation(anySet());
+
+        // Step 2: put the state machine into the appropriate... state
+        membershipManager.transitionToFenced();
+        assertEquals(MemberState.FENCED, membershipManager.state());
+        assertEquals(Collections.emptySet(), membershipManager.currentAssignment());
+        listener.assertCounts(0, 0, 0);
+
+        // Step 3: invoke the callback
+        performCallback(
+                membershipManager,
+                invoker,
+                ConsumerRebalanceListenerMethodName.ON_PARTITIONS_LOST,
+                topicPartitions(new TopicPartition("topic1", 0))
+        );
+
+        // Step 4: Receive ack and make sure we're done and our listener was called appropriately
+        membershipManager.onHeartbeatRequestSent();
+        assertEquals(MemberState.JOINING, membershipManager.state());
+
+        listener.assertCounts(0, 0, 1);
+    }
+
+    private ConsumerRebalanceListenerInvoker consumerRebalanceListenerInvoker() {
+        ConsumerCoordinatorMetrics coordinatorMetrics = new ConsumerCoordinatorMetrics(
+                subscriptionState,
+                new Metrics(),
+                "test-");
+        return new ConsumerRebalanceListenerInvoker(
+                new LogContext(),
+                subscriptionState,
+                new MockTime(1),
+                coordinatorMetrics
+        );
+    }
+
+    private SortedSet<TopicPartition> topicPartitions(TopicPartition... topicPartitions) {
+        SortedSet<TopicPartition> revokedPartitions = new TreeSet<>(new Utils.TopicPartitionComparator());
+        revokedPartitions.addAll(Arrays.asList(topicPartitions));
+        return revokedPartitions;
+    }
+
+    private void performCallback(MembershipManagerImpl membershipManager,
+                                 ConsumerRebalanceListenerInvoker invoker,
+                                 ConsumerRebalanceListenerMethodName methodName,
+                                 SortedSet<TopicPartition> partitions) {
+        // Set up our mock application event handler & background event processor.
+        ApplicationEventHandler applicationEventHandler = mock(ApplicationEventHandler.class);
+
+        doAnswer(a -> {
+            ConsumerRebalanceListenerCallbackCompletedEvent completedEvent = a.getArgument(0);
+            membershipManager.consumerRebalanceListenerCallbackCompleted(completedEvent);
+            return null;
+        }).when(applicationEventHandler).add(any(ConsumerRebalanceListenerCallbackCompletedEvent.class));
+
+        // We expect only our enqueued event in the background queue.
+        assertEquals(1, backgroundEventQueue.size());
+        assertNotNull(backgroundEventQueue.peek());
+        assertInstanceOf(ConsumerRebalanceListenerCallbackNeededEvent.class, backgroundEventQueue.peek());
+        ConsumerRebalanceListenerCallbackNeededEvent neededEvent = (ConsumerRebalanceListenerCallbackNeededEvent) backgroundEventQueue.poll();
+        assertNotNull(neededEvent);
+        assertEquals(methodName, neededEvent.methodName());
+        assertEquals(partitions, neededEvent.partitions());
+
+        final Exception e;
+
+        switch (methodName) {
+            case ON_PARTITIONS_REVOKED:
+                e = invoker.invokePartitionsRevoked(partitions);
+                break;
+
+            case ON_PARTITIONS_ASSIGNED:
+                e = invoker.invokePartitionsAssigned(partitions);
+                break;
+
+            case ON_PARTITIONS_LOST:
+                e = invoker.invokePartitionsLost(partitions);
+                break;
+
+            default:
+                throw new IllegalArgumentException("The method " + methodName + " to invoke was not expected");
+        }
+
+        final Optional<KafkaException> error;
+
+        if (e != null) {
+            if (e instanceof KafkaException)
+                error = Optional.of((KafkaException) e);
+            else
+                error = Optional.of(new KafkaException("User rebalance callback throws an error", e));
+        } else {
+            error = Optional.empty();
+        }
+
+        ApplicationEvent invokedEvent = new ConsumerRebalanceListenerCallbackCompletedEvent(
+                methodName,
+                partitions,
+                neededEvent.future(),
+                error);
+        applicationEventHandler.add(invokedEvent);

Review Comment:
   Should we just call `consumerRebalanceListenerCallbackCompleted()` instead of going through the application event handler?
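A sketch of the direct-call alternative, assuming `consumerRebalanceListenerCallbackCompleted()` is what the mocked `ApplicationEventHandler.add()` ends up invoking anyway (that is all the `doAnswer` in `performCallback` does). `CompletedEvent`, `FakeMembershipManager` and `DirectCallbackCompletion` are hypothetical stand-ins, and the future-completion behaviour in the fake manager is an assumption, not the actual `MembershipManagerImpl` logic.

```java
import java.util.Optional;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.CompletableFuture;

// Hypothetical stand-in for ConsumerRebalanceListenerCallbackCompletedEvent.
final class CompletedEvent {
    final String methodName;
    final SortedSet<String> partitions;
    final CompletableFuture<Void> future;
    final Optional<RuntimeException> error;

    CompletedEvent(String methodName, SortedSet<String> partitions,
                   CompletableFuture<Void> future, Optional<RuntimeException> error) {
        this.methodName = methodName;
        this.partitions = partitions;
        this.future = future;
        this.error = error;
    }
}

// Hypothetical stand-in for MembershipManagerImpl, reduced to the one method under discussion.
final class FakeMembershipManager {
    int completedCallbacks = 0;

    void consumerRebalanceListenerCallbackCompleted(CompletedEvent event) {
        completedCallbacks++;
        // Assumed behaviour: complete the pending future, exceptionally if the callback failed.
        if (event.error.isPresent())
            event.future.completeExceptionally(event.error.get());
        else
            event.future.complete(null);
    }
}

final class DirectCallbackCompletion {
    // No mocked ApplicationEventHandler and no doAnswer(): the helper hands
    // the completed event straight to the membership manager.
    static void completeCallback(FakeMembershipManager membershipManager,
                                 String methodName,
                                 SortedSet<String> partitions,
                                 CompletableFuture<Void> future,
                                 Optional<RuntimeException> error) {
        CompletedEvent completedEvent = new CompletedEvent(methodName, partitions, future, error);
        membershipManager.consumerRebalanceListenerCallbackCompleted(completedEvent);
    }
}
```

The trade-off looks small: the mock only forwarded the event, so calling the manager directly drops the Mockito setup without losing coverage of any real `ApplicationEventHandler` code.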



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

