This is an automated email from the ASF dual-hosted git repository.

lhotari pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/pulsar.git

commit 03de43e0b9405578fc07857fdf0db2bcf744082e
Author: Lari Hotari <[email protected]>
AuthorDate: Wed May 14 10:00:26 2025 +0300

    [fix][test] Fix more Netty ByteBuf leaks in tests (#24299)
    
    (cherry picked from commit 174245d322574ef422f08a4c249db4e649010c83)
---
 .../pulsar/broker/service/EntryAndMetadata.java    |   2 +-
 ...ntStickyKeyDispatcherMultipleConsumersTest.java | 495 ++++++++++++---------
 .../impl/AcknowledgementsGroupingTrackerTest.java  |  45 +-
 .../client/impl/BinaryProtoLookupServiceTest.java  |  15 +-
 .../apache/pulsar/client/impl/ClientCnxTest.java   |  56 +--
 .../pulsar/client/impl/ClientTestFixtures.java     |  92 +++-
 6 files changed, 420 insertions(+), 285 deletions(-)

diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/EntryAndMetadata.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/EntryAndMetadata.java
index 70643d5de2a..d9cdedf6477 100644
--- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/EntryAndMetadata.java
+++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/EntryAndMetadata.java
@@ -45,7 +45,7 @@ public class EntryAndMetadata implements Entry {
     }
 
     @VisibleForTesting
-    static EntryAndMetadata create(final Entry entry) {
+    public static EntryAndMetadata create(final Entry entry) {
         return create(entry, 
Commands.peekAndCopyMessageMetadata(entry.getDataBuffer(), "", -1));
     }
 
diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/PersistentStickyKeyDispatcherMultipleConsumersTest.java b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/PersistentStickyKeyDispatcherMultipleConsumersTest.java
index bf3a79133b0..9b7c98cc30e 100644
--- a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/PersistentStickyKeyDispatcherMultipleConsumersTest.java
+++ b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/PersistentStickyKeyDispatcherMultipleConsumersTest.java
@@ -27,10 +27,8 @@ import static org.mockito.Mockito.anyInt;
 import static org.mockito.Mockito.anyList;
 import static org.mockito.Mockito.anyLong;
 import static org.mockito.Mockito.anySet;
-import static org.mockito.Mockito.argThat;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -40,21 +38,27 @@ import static org.testng.Assert.assertTrue;
 import static org.testng.Assert.fail;
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
-import io.netty.channel.ChannelPromise;
 import io.netty.channel.EventLoopGroup;
+import io.netty.util.concurrent.EventExecutor;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.SucceededFuture;
 import java.lang.reflect.Field;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.Queue;
+import java.util.Objects;
 import java.util.Set;
-import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.TreeSet;
+import java.util.concurrent.ConcurrentSkipListSet;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 import org.apache.bookkeeper.common.util.OrderedExecutor;
+import org.apache.bookkeeper.mledger.AsyncCallbacks;
 import org.apache.bookkeeper.mledger.Entry;
 import org.apache.bookkeeper.mledger.Position;
 import org.apache.bookkeeper.mledger.impl.EntryImpl;
@@ -64,10 +68,13 @@ import org.apache.pulsar.broker.PulsarService;
 import org.apache.pulsar.broker.ServiceConfiguration;
 import org.apache.pulsar.broker.service.BrokerService;
 import org.apache.pulsar.broker.service.Consumer;
+import org.apache.pulsar.broker.service.EntryAndMetadata;
 import org.apache.pulsar.broker.service.EntryBatchIndexesAcks;
 import org.apache.pulsar.broker.service.EntryBatchSizes;
 import org.apache.pulsar.broker.service.RedeliveryTracker;
 import org.apache.pulsar.broker.service.StickyKeyConsumerSelector;
+import org.apache.pulsar.broker.service.TransportCnx;
+import org.apache.pulsar.broker.service.plugin.EntryFilterProvider;
 import org.apache.pulsar.common.api.proto.KeySharedMeta;
 import org.apache.pulsar.common.api.proto.KeySharedMode;
 import org.apache.pulsar.common.api.proto.MessageMetadata;
@@ -77,6 +84,7 @@ import org.apache.pulsar.common.protocol.Markers;
 import org.awaitility.Awaitility;
 import org.mockito.ArgumentCaptor;
 import org.testng.Assert;
+import org.testng.annotations.AfterMethod;
 import org.testng.annotations.BeforeMethod;
 import org.testng.annotations.Test;
 
@@ -90,13 +98,14 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest {
     private PersistentTopic topicMock;
     private PersistentSubscription subscriptionMock;
     private ServiceConfiguration configMock;
-    private ChannelPromise channelMock;
+    private Future<Void> succeededFuture;
     private OrderedExecutor orderedExecutor;
 
     private PersistentStickyKeyDispatcherMultipleConsumers 
persistentDispatcher;
 
     final String topicName = "persistent://public/default/testTopic";
     final String subscriptionName = "testSubscription";
+    private AtomicInteger consumerMockAvailablePermits;
 
     @BeforeMethod
     public void setup() throws Exception {
@@ -106,12 +115,17 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest {
         
doReturn(true).when(configMock).isSubscriptionKeySharedUseConsistentHashing();
         
doReturn(1).when(configMock).getSubscriptionKeySharedConsistentHashingReplicaPoints();
         
doReturn(true).when(configMock).isDispatcherDispatchMessagesInSubscriptionThread();
+        doReturn(false).when(configMock).isAllowOverrideEntryFilters();
 
         pulsarMock = mock(PulsarService.class);
         doReturn(configMock).when(pulsarMock).getConfiguration();
 
+        EntryFilterProvider mockEntryFilterProvider = 
mock(EntryFilterProvider.class);
+        
when(mockEntryFilterProvider.getBrokerEntryFilters()).thenReturn(Collections.emptyList());
+
         brokerMock = mock(BrokerService.class);
         doReturn(pulsarMock).when(brokerMock).pulsar();
+        
when(brokerMock.getEntryFilterProvider()).thenReturn(mockEntryFilterProvider);
 
         HierarchyTopicPolicies topicPolicies = new HierarchyTopicPolicies();
         topicPolicies.getMaxConsumersPerSubscription().updateBrokerValue(0);
@@ -122,7 +136,7 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest {
         EventLoopGroup eventLoopGroup = mock(EventLoopGroup.class);
         doReturn(eventLoopGroup).when(brokerMock).executor();
         doAnswer(invocation -> {
-            orderedExecutor.execute(((Runnable)invocation.getArguments()[0]));
+            orderedExecutor.execute(invocation.getArgument(0, Runnable.class));
             return null;
         }).when(eventLoopGroup).execute(any(Runnable.class));
 
@@ -135,12 +149,36 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest {
         doReturn(null).when(cursorMock).getLastIndividualDeletedRange();
         doReturn(subscriptionName).when(cursorMock).getName();
 
-        consumerMock = mock(Consumer.class);
-        channelMock = mock(ChannelPromise.class);
+        consumerMock = createMockConsumer();
+        EventExecutor eventExecutor = mock(EventExecutor.class);
+        doAnswer(invocation -> {
+            orderedExecutor.execute(invocation.getArgument(0, Runnable.class));
+            return null;
+        }).when(eventExecutor).execute(any(Runnable.class));
+        doReturn(false).when(eventExecutor).inEventLoop();
+        succeededFuture = new SucceededFuture<>(eventExecutor, null);
         doReturn("consumer1").when(consumerMock).consumerName();
-        doReturn(1000).when(consumerMock).getAvailablePermits();
+        consumerMockAvailablePermits = new AtomicInteger(1000);
+        doAnswer(invocation -> 
consumerMockAvailablePermits.get()).when(consumerMock).getAvailablePermits();
         doReturn(true).when(consumerMock).isWritable();
-        doReturn(channelMock).when(consumerMock).sendMessages(
+        mockSendMessages(consumerMock, null);
+
+        subscriptionMock = mock(PersistentSubscription.class);
+        when(subscriptionMock.getTopic()).thenReturn(topicMock);
+        persistentDispatcher = new 
PersistentStickyKeyDispatcherMultipleConsumers(
+                topicMock, cursorMock, subscriptionMock, configMock,
+                new 
KeySharedMeta().setKeySharedMode(KeySharedMode.AUTO_SPLIT));
+    }
+
+    private void mockSendMessages(Consumer consumerMock, 
java.util.function.Consumer<List<Entry>> entryConsumer) {
+        doAnswer(invocation -> {
+            List<Entry> entries = invocation.getArgument(0);
+            if (entryConsumer != null) {
+                entryConsumer.accept(entries);
+            }
+            entries.stream().filter(Objects::nonNull).forEach(Entry::release);
+            return succeededFuture;
+        }).when(consumerMock).sendMessages(
                 anyList(),
                 any(EntryBatchSizes.class),
                 any(EntryBatchIndexesAcks.class),
@@ -149,16 +187,25 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest {
                 anyLong(),
                 any(RedeliveryTracker.class)
         );
+    }
 
-        subscriptionMock = mock(PersistentSubscription.class);
-        persistentDispatcher = new 
PersistentStickyKeyDispatcherMultipleConsumers(
-                topicMock, cursorMock, subscriptionMock, configMock,
-                new 
KeySharedMeta().setKeySharedMode(KeySharedMode.AUTO_SPLIT));
+    protected static Consumer createMockConsumer() {
+        Consumer consumerMock = mock(Consumer.class);
+        TransportCnx transportCnx = mock(TransportCnx.class);
+        doReturn(transportCnx).when(consumerMock).cnx();
+        doReturn(true).when(transportCnx).isActive();
+        doReturn(100).when(consumerMock).getMaxUnackedMessages();
+        doReturn(1).when(consumerMock).getAvgMessagesPerEntry();
+        return consumerMock;
     }
 
+    @AfterMethod(alwaysRun = true)
     public void cleanup() {
+        if (persistentDispatcher != null && !persistentDispatcher.isClosed()) {
+            persistentDispatcher.close();
+        }
         if (orderedExecutor != null) {
-            orderedExecutor.shutdown();
+            orderedExecutor.shutdownNow();
             orderedExecutor = null;
         }
     }
@@ -166,8 +213,8 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest {
     @Test(timeOut = 10000)
     public void testAddConsumerWhenClosed() throws Exception {
         persistentDispatcher.close().get();
-        Consumer consumer = mock(Consumer.class);
-        persistentDispatcher.addConsumer(consumer);
+        Consumer consumer = createMockConsumer();
+        persistentDispatcher.addConsumer(consumer).join();
         verify(consumer, times(1)).disconnect();
         assertEquals(0, persistentDispatcher.getConsumers().size());
         
assertTrue(persistentDispatcher.getSelector().getConsumerKeyHashRanges().isEmpty());
@@ -180,19 +227,19 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest {
                 topicMock, cursorMock, subscriptionMock, configMock,
                 new 
KeySharedMeta().setKeySharedMode(KeySharedMode.AUTO_SPLIT));
 
-        Consumer consumer0 = mock(Consumer.class);
+        Consumer consumer0 = createMockConsumer();
         when(consumer0.consumerName()).thenReturn("c0-1");
-        Consumer consumer1 = mock(Consumer.class);
+        Consumer consumer1 = createMockConsumer();
         when(consumer1.consumerName()).thenReturn("c1");
-        Consumer consumer2 = mock(Consumer.class);
+        Consumer consumer2 = createMockConsumer();
         when(consumer2.consumerName()).thenReturn("c2");
-        Consumer consumer3 = mock(Consumer.class);
+        Consumer consumer3 = createMockConsumer();
         when(consumer3.consumerName()).thenReturn("c3");
-        Consumer consumer4 = mock(Consumer.class);
+        Consumer consumer4 = createMockConsumer();
         when(consumer4.consumerName()).thenReturn("c4");
-        Consumer consumer5 = mock(Consumer.class);
+        Consumer consumer5 = createMockConsumer();
         when(consumer5.consumerName()).thenReturn("c5");
-        Consumer consumer6 = mock(Consumer.class);
+        Consumer consumer6 = createMockConsumer();
         when(consumer6.consumerName()).thenReturn("c6");
 
         
when(cursorMock.getNumberOfEntriesSinceFirstNotAckedMessage()).thenReturn(100L);
@@ -255,7 +302,7 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest {
     @Test
     public void testSendMarkerMessage() {
         try {
-            persistentDispatcher.addConsumer(consumerMock);
+            persistentDispatcher.addConsumer(consumerMock).join();
             persistentDispatcher.consumerFlow(consumerMock, 1000);
         } catch (Exception e) {
             fail("Failed to add mock consumer", e);
@@ -264,14 +311,16 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest {
         List<Entry> entries = new ArrayList<>();
         ByteBuf markerMessage = 
Markers.newReplicatedSubscriptionsSnapshotRequest("testSnapshotId", 
"testSourceCluster");
         entries.add(EntryImpl.create(1, 1, markerMessage));
-        entries.add(EntryImpl.create(1, 2, createMessage("message1", 1)));
-        entries.add(EntryImpl.create(1, 3, createMessage("message2", 2)));
-        entries.add(EntryImpl.create(1, 4, createMessage("message3", 3)));
-        entries.add(EntryImpl.create(1, 5, createMessage("message4", 4)));
-        entries.add(EntryImpl.create(1, 6, createMessage("message5", 5)));
+        markerMessage.release();
+        entries.add(createEntry(1, 2, "message1", 1));
+        entries.add(createEntry(1, 3, "message2", 2));
+        entries.add(createEntry(1, 4, "message3", 3));
+        entries.add(createEntry(1, 5, "message4", 4));
+        entries.add(createEntry(1, 6, "message5", 5));
 
         try {
-            persistentDispatcher.readEntriesComplete(entries, 
PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal);
+            persistentDispatcher.readEntriesComplete(copyEntries(entries),
+                    
PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal);
         } catch (Exception e) {
             fail("Failed to readEntriesComplete.", e);
         }
@@ -291,243 +340,264 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest {
             List<Integer> allTotalMessagesCaptor = 
totalMessagesCaptor.getAllValues();
             Assert.assertEquals(allTotalMessagesCaptor.get(0).intValue(), 5);
         });
+
+        entries.forEach(Entry::release);
+    }
+
+    private static List<Entry> copyEntries(List<Entry> entries) {
+        return entries.stream().map(entry -> EntryImpl.create((EntryImpl) 
entry))
+                .collect(Collectors.toList());
     }
 
     @Test(timeOut = 10000)
     public void testSendMessage() {
         KeySharedMeta keySharedMeta = new 
KeySharedMeta().setKeySharedMode(KeySharedMode.STICKY);
-        PersistentStickyKeyDispatcherMultipleConsumers persistentDispatcher = 
new PersistentStickyKeyDispatcherMultipleConsumers(
+        PersistentStickyKeyDispatcherMultipleConsumers
+                persistentDispatcher = new 
PersistentStickyKeyDispatcherMultipleConsumers(
                 topicMock, cursorMock, subscriptionMock, configMock, 
keySharedMeta);
         try {
             keySharedMeta.addHashRange()
                     .setStart(0)
                     .setEnd(9);
 
-            Consumer consumerMock = mock(Consumer.class);
+            Consumer consumerMock = createMockConsumer();
             doReturn(keySharedMeta).when(consumerMock).getKeySharedMeta();
-            persistentDispatcher.addConsumer(consumerMock);
+            mockSendMessages(consumerMock, null);
+            persistentDispatcher.addConsumer(consumerMock).join();
             persistentDispatcher.consumerFlow(consumerMock, 1000);
         } catch (Exception e) {
             fail("Failed to add mock consumer", e);
         }
 
         List<Entry> entries = new ArrayList<>();
-        entries.add(EntryImpl.create(1, 1, createMessage("message1", 1)));
-        entries.add(EntryImpl.create(1, 2, createMessage("message2", 2)));
+        entries.add(createEntry(1, 1, "message1", 1));
+        entries.add(createEntry(1, 2, "message2", 2));
 
         try {
             //Should success,see issue #8960
-            persistentDispatcher.readEntriesComplete(entries, 
PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal);
+            persistentDispatcher.readEntriesComplete(copyEntries(entries),
+                    
PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal);
         } catch (Exception e) {
             fail("Failed to readEntriesComplete.", e);
         }
+
+        entries.forEach(Entry::release);
     }
 
     @Test
-    public void testSkipRedeliverTemporally() {
-        final Consumer slowConsumerMock = mock(Consumer.class);
-        final ChannelPromise slowChannelMock = mock(ChannelPromise.class);
-        // add entries to redeliver and read target
+    public void testSkipRedeliverTemporally() throws InterruptedException {
+        // add first consumer
+        persistentDispatcher.addConsumer(consumerMock).join();
+        // add slow consumer
+        final Consumer slowConsumerMock = createMockConsumer();
+        doReturn("consumer2").when(slowConsumerMock).consumerName();
+        AtomicInteger slowConsumerAvailablePermits = new AtomicInteger(0);
+        doAnswer(invocation -> {
+            return slowConsumerAvailablePermits.get();
+        }).when(slowConsumerMock).getAvailablePermits();
+        persistentDispatcher.addConsumer(slowConsumerMock).join();
+
+        StickyKeyConsumerSelector selector = 
persistentDispatcher.getSelector();
+        String keyForConsumer = generateKeyForConsumer(selector, consumerMock);
+        String keyForSlowConsumer = generateKeyForConsumer(selector, 
slowConsumerMock);
+
+        Set<Position> alreadySent = new ConcurrentSkipListSet<>();
+
+        final List<Entry> allEntries = new ArrayList<>();
+        allEntries.add(createEntry(1, 1, "message1", 1, keyForSlowConsumer));
+        allEntries.add(createEntry(1, 2, "message2", 2, keyForSlowConsumer));
+        allEntries.add(createEntry(1, 3, "message3", 3, keyForConsumer));
+
+        // add first entry to redeliver initially
         final List<Entry> redeliverEntries = new ArrayList<>();
-        redeliverEntries.add(EntryImpl.create(1, 1, createMessage("message1", 
1, "key1")));
-        final List<Entry> readEntries = new ArrayList<>();
-        readEntries.add(EntryImpl.create(1, 2, createMessage("message2", 2, 
"key1")));
-        readEntries.add(EntryImpl.create(1, 3, createMessage("message3", 3, 
"key2")));
+        redeliverEntries.add(allEntries.get(0));
 
         try {
-            Field totalAvailablePermitsField = 
PersistentDispatcherMultipleConsumers.class.getDeclaredField("totalAvailablePermits");
+            Field totalAvailablePermitsField =
+                    
PersistentDispatcherMultipleConsumers.class.getDeclaredField("totalAvailablePermits");
             totalAvailablePermitsField.setAccessible(true);
             totalAvailablePermitsField.set(persistentDispatcher, 1000);
-
-            doAnswer(invocationOnMock -> {
-                ((PersistentStickyKeyDispatcherMultipleConsumers) 
invocationOnMock.getArgument(2))
-                        .readEntriesComplete(readEntries, 
PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal);
-                return null;
-            }).when(cursorMock).asyncReadEntriesOrWait(
-                    anyInt(), anyLong(), 
any(PersistentStickyKeyDispatcherMultipleConsumers.class),
-                    
eq(PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal), any());
         } catch (Exception e) {
             fail("Failed to set to field", e);
         }
 
-        // Create 2Consumers
-        try {
-            doReturn("consumer2").when(slowConsumerMock).consumerName();
-            // Change slowConsumer availablePermits to 0 and back to normal
-            when(slowConsumerMock.getAvailablePermits())
-                    .thenReturn(0)
-                    .thenReturn(1);
-            doReturn(true).when(slowConsumerMock).isWritable();
-            doReturn(slowChannelMock).when(slowConsumerMock).sendMessages(
-                    anyList(),
-                    any(EntryBatchSizes.class),
-                    any(EntryBatchIndexesAcks.class),
-                    anyInt(),
-                    anyLong(),
-                    anyLong(),
-                    any(RedeliveryTracker.class)
-            );
-
-            persistentDispatcher.addConsumer(consumerMock);
-            persistentDispatcher.addConsumer(slowConsumerMock);
-        } catch (Exception e) {
-            fail("Failed to add mock consumer", e);
-        }
+        // Mock Cursor#asyncReplayEntries
+        doAnswer(invocationOnMock -> {
+            Set<Position> positionsArg = invocationOnMock.getArgument(0);
+            Set<Position> positions = new TreeSet<>(positionsArg);
+            List<Entry> entries = allEntries.stream()
+                    .filter(entry -> entry.getLedgerId() != -1 && 
positions.contains(entry.getPosition()))
+                    .toList();
+            AsyncCallbacks.ReadEntriesCallback callback = 
invocationOnMock.getArgument(1);
+            Object ctx = invocationOnMock.getArgument(2);
+            callback.readEntriesComplete(copyEntries(entries), ctx);
+            return Collections.emptySet();
+        }).when(cursorMock).asyncReplayEntries(anySet(), any(), any(), 
anyBoolean());
 
-        // run 
PersistentStickyKeyDispatcherMultipleConsumers#sendMessagesToConsumers
-        // run readMoreEntries internally (and skip internally)
-        // Change slowConsumer availablePermits to 1
-        // run 
PersistentStickyKeyDispatcherMultipleConsumers#sendMessagesToConsumers 
internally
-        // and then stop to dispatch to slowConsumer
-        if 
(persistentDispatcher.sendMessagesToConsumers(PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal,
-                redeliverEntries, true)) {
-            persistentDispatcher.readMoreEntriesAsync();
-        }
+        doAnswer(invocationOnMock -> {
+            int maxEntries = invocationOnMock.getArgument(0);
+            AsyncCallbacks.ReadEntriesCallback callback = 
invocationOnMock.getArgument(2);
+            List<Entry> entries = allEntries.stream()
+                    .filter(entry -> entry.getLedgerId() != -1 && 
!alreadySent.contains(entry.getPosition()))
+                    .limit(maxEntries)
+                    .toList();
+            Object ctx = invocationOnMock.getArgument(3);
+            callback.readEntriesComplete(copyEntries(entries), ctx);
+            return null;
+        }).when(cursorMock).asyncReadEntriesOrWait(anyInt(), anyLong(), any(), 
any(), any());
+
+        doReturn(true).when(slowConsumerMock).isWritable();
+        CountDownLatch message3Sent = new CountDownLatch(1);
+        mockSendMessages(consumerMock, entries -> {
+            entries.forEach(entry -> {
+                alreadySent.add(entry.getPosition());
+            });
+            boolean message3Found = entries.stream()
+                    .anyMatch(entry -> entry.getLedgerId() == 1 && 
entry.getEntryId() == 3);
+            if (message3Found) {
+                message3Sent.countDown();
+            }
+        });
+        CountDownLatch slowConsumerMessagesSent = new CountDownLatch(2);
+        mockSendMessages(slowConsumerMock, entries -> {
+            entries.forEach(entry -> {
+                alreadySent.add(entry.getPosition());
+                slowConsumerMessagesSent.countDown();
+            });
+        });
 
-        Awaitility.await().untilAsserted(() -> {
-            verify(consumerMock, times(1)).sendMessages(
-                    argThat(arg -> {
-                        assertEquals(arg.size(), 1);
-                        Entry entry = arg.get(0);
-                        assertEquals(entry.getLedgerId(), 1);
-                        assertEquals(entry.getEntryId(), 3);
-                        return true;
-                    }),
-                    any(EntryBatchSizes.class),
-                    any(EntryBatchIndexesAcks.class),
-                    anyInt(),
-                    anyLong(),
-                    anyLong(),
-                    any(RedeliveryTracker.class)
-            );
+        // add entries to redeliver
+        redeliverEntries.forEach(entry -> {
+            // calculate hash
+            EntryAndMetadata entryAndMetadata = EntryAndMetadata.create(entry);
+            int stickyKeyHash = 
StickyKeyConsumerSelector.makeStickyKeyHash(entryAndMetadata.getStickyKey());
+            // add to redeliver
+            persistentDispatcher.addMessageToReplay(entry.getLedgerId(), 
entry.getEntryId(), stickyKeyHash);
         });
-        verify(slowConsumerMock, times(0)).sendMessages(
-                anyList(),
-                any(EntryBatchSizes.class),
-                any(EntryBatchIndexesAcks.class),
-                anyInt(),
-                anyLong(),
-                anyLong(),
-                any(RedeliveryTracker.class)
-        );
+
+        // trigger readMoreEntries, will handle redelivery logic and skip slow 
consumer
+        persistentDispatcher.readMoreEntriesAsync();
+
+        assertTrue(message3Sent.await(5, TimeUnit.SECONDS));
+
+        // verify that slow consumer messages are not sent before message3 to 
"consumer"
+        assertEquals(slowConsumerMessagesSent.getCount(), 2);
+
+        // set permits to 2
+        slowConsumerAvailablePermits.set(2);
+
+        // now wait for slow consumer messages since there are permits
+        assertTrue(slowConsumerMessagesSent.await(5, TimeUnit.SECONDS));
+
+        allEntries.forEach(Entry::release);
     }
 
     @Test(timeOut = 30000)
     public void testMessageRedelivery() throws Exception {
-        final Queue<Position> actualEntriesToConsumer1 = new 
ConcurrentLinkedQueue<>();
-        final Queue<Position> actualEntriesToConsumer2 = new 
ConcurrentLinkedQueue<>();
+        final List<Position> actualEntriesToConsumer1 = new 
CopyOnWriteArrayList<>();
+        final List<Position> actualEntriesToConsumer2 = new 
CopyOnWriteArrayList<>();
 
-        final Queue<Position> expectedEntriesToConsumer1 = new 
ConcurrentLinkedQueue<>();
-        expectedEntriesToConsumer1.add(PositionImpl.get(1, 1));
-        final Queue<Position> expectedEntriesToConsumer2 = new 
ConcurrentLinkedQueue<>();
-        expectedEntriesToConsumer2.add(PositionImpl.get(1, 2));
-        expectedEntriesToConsumer2.add(PositionImpl.get(1, 3));
+        final List<Position> expectedEntriesToConsumer1 = new 
CopyOnWriteArrayList<>();
+        final List<Position> expectedEntriesToConsumer2 = new 
CopyOnWriteArrayList<>();
 
-        final AtomicInteger remainingEntriesNum = new AtomicInteger(
-                expectedEntriesToConsumer1.size() + 
expectedEntriesToConsumer2.size());
-
-        // Messages with key1 are routed to consumer1 and messages with key2 
are routed to consumer2
-        final List<Entry> allEntries = new ArrayList<>();
-        allEntries.add(EntryImpl.create(1, 1, createMessage("message1", 1, 
"key2")));
-        allEntries.add(EntryImpl.create(1, 2, createMessage("message2", 2, 
"key1")));
-        allEntries.add(EntryImpl.create(1, 3, createMessage("message3", 3, 
"key1")));
-        allEntries.forEach(entry -> ((EntryImpl) entry).retain());
+        final CountDownLatch remainingEntriesNum = new CountDownLatch(3);
 
-        final List<Entry> redeliverEntries = new ArrayList<>();
-        redeliverEntries.add(allEntries.get(0)); // message1
-        final List<Entry> readEntries = new ArrayList<>();
-        readEntries.add(allEntries.get(2)); // message3
-
-        final Consumer consumer1 = mock(Consumer.class);
+        final Consumer consumer1 = createMockConsumer();
         doReturn("consumer1").when(consumer1).consumerName();
         // Change availablePermits of consumer1 to 0 and then back to normal
         when(consumer1.getAvailablePermits()).thenReturn(0).thenReturn(10);
         doReturn(true).when(consumer1).isWritable();
-        doAnswer(invocationOnMock -> {
-            @SuppressWarnings("unchecked")
-            List<Entry> entries = (List<Entry>) 
invocationOnMock.getArgument(0);
+        mockSendMessages(consumer1, entries -> {
             for (Entry entry : entries) {
-                remainingEntriesNum.decrementAndGet();
                 actualEntriesToConsumer1.add(entry.getPosition());
+                remainingEntriesNum.countDown();
             }
-            return channelMock;
-        }).when(consumer1).sendMessages(anyList(), any(EntryBatchSizes.class), 
any(EntryBatchIndexesAcks.class),
-                anyInt(), anyLong(), anyLong(), any(RedeliveryTracker.class));
+        });
 
-        final Consumer consumer2 = mock(Consumer.class);
+        final Consumer consumer2 = createMockConsumer();
         doReturn("consumer2").when(consumer2).consumerName();
         when(consumer2.getAvailablePermits()).thenReturn(10);
         doReturn(true).when(consumer2).isWritable();
-        doAnswer(invocationOnMock -> {
-            @SuppressWarnings("unchecked")
-            List<Entry> entries = (List<Entry>) 
invocationOnMock.getArgument(0);
+        mockSendMessages(consumer2, entries -> {
             for (Entry entry : entries) {
-                remainingEntriesNum.decrementAndGet();
                 actualEntriesToConsumer2.add(entry.getPosition());
+                remainingEntriesNum.countDown();
             }
-            return channelMock;
-        }).when(consumer2).sendMessages(anyList(), any(EntryBatchSizes.class), 
any(EntryBatchIndexesAcks.class),
-                anyInt(), anyLong(), anyLong(), any(RedeliveryTracker.class));
+        });
 
-        persistentDispatcher.addConsumer(consumer1);
-        persistentDispatcher.addConsumer(consumer2);
+        persistentDispatcher.addConsumer(consumer1).join();
+        persistentDispatcher.addConsumer(consumer2).join();
 
         final Field totalAvailablePermitsField = 
PersistentDispatcherMultipleConsumers.class
                 .getDeclaredField("totalAvailablePermits");
         totalAvailablePermitsField.setAccessible(true);
         totalAvailablePermitsField.set(persistentDispatcher, 1000);
 
-        final Field redeliveryMessagesField = 
PersistentDispatcherMultipleConsumers.class
-                .getDeclaredField("redeliveryMessages");
-        redeliveryMessagesField.setAccessible(true);
-        MessageRedeliveryController redeliveryMessages = 
(MessageRedeliveryController) redeliveryMessagesField
-                .get(persistentDispatcher);
-        redeliveryMessages.add(allEntries.get(0).getLedgerId(), 
allEntries.get(0).getEntryId(),
-                getStickyKeyHash(allEntries.get(0))); // message1
-        redeliveryMessages.add(allEntries.get(1).getLedgerId(), 
allEntries.get(1).getEntryId(),
-                getStickyKeyHash(allEntries.get(1))); // message2
+        StickyKeyConsumerSelector selector = 
persistentDispatcher.getSelector();
+
+        String keyForConsumer1 = generateKeyForConsumer(selector, consumer1);
+        String keyForConsumer2 = generateKeyForConsumer(selector, consumer2);
+
+        // Messages with key1 are routed to consumer1 and messages with key2 
are routed to consumer2
+        final List<Entry> allEntries = new ArrayList<>();
+        allEntries.add(createEntry(1, 1, "message1", 1, keyForConsumer1));
+        allEntries.add(createEntry(1, 2, "message2", 2, keyForConsumer1));
+        allEntries.add(createEntry(1, 3, "message3", 3, keyForConsumer2));
+
+        // add first entry to redeliver initially
+        final List<Entry> redeliverEntries = new ArrayList<>();
+        redeliverEntries.add(allEntries.get(0)); // message1
+
+        expectedEntriesToConsumer1.add(allEntries.get(0).getPosition());
+        expectedEntriesToConsumer1.add(allEntries.get(1).getPosition());
+        expectedEntriesToConsumer2.add(allEntries.get(2).getPosition());
 
         // Mock Cursor#asyncReplayEntries
         doAnswer(invocationOnMock -> {
-            @SuppressWarnings("unchecked")
-            Set<Position> positions = (Set<Position>) 
invocationOnMock.getArgument(0);
-            List<Entry> entries = allEntries.stream().filter(entry -> 
positions.contains(entry.getPosition()))
+            Set<Position> positionsArg = invocationOnMock.getArgument(0);
+            Set<Position> positions = new TreeSet<>(positionsArg);
+            Set<Position> alreadyReceived = new TreeSet<>();
+            alreadyReceived.addAll(actualEntriesToConsumer1);
+            alreadyReceived.addAll(actualEntriesToConsumer2);
+            List<Entry> entries = allEntries.stream().filter(entry -> 
entry.getLedgerId() != -1
+                            && positions.contains(entry.getPosition())
+                            && !alreadyReceived.contains(entry.getPosition()))
                     .collect(Collectors.toList());
-            if (!entries.isEmpty()) {
-                ((PersistentStickyKeyDispatcherMultipleConsumers) 
invocationOnMock.getArgument(1))
-                        .readEntriesComplete(entries, 
PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Replay);
-            }
-            return Collections.emptySet();
-        }).when(cursorMock).asyncReplayEntries(anySet(), 
any(PersistentStickyKeyDispatcherMultipleConsumers.class),
-                
eq(PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Replay), 
anyBoolean());
+            AsyncCallbacks.ReadEntriesCallback callback = 
invocationOnMock.getArgument(1);
+            Object ctx = invocationOnMock.getArgument(2);
+            callback.readEntriesComplete(copyEntries(entries), ctx);
+            return alreadyReceived;
+        }).when(cursorMock).asyncReplayEntries(anySet(), any(), any(), 
anyBoolean());
 
         // Mock Cursor#asyncReadEntriesOrWait
-        AtomicBoolean asyncReadEntriesOrWaitCalled = new AtomicBoolean();
         doAnswer(invocationOnMock -> {
-            if (asyncReadEntriesOrWaitCalled.compareAndSet(false, true)) {
-                ((PersistentStickyKeyDispatcherMultipleConsumers) 
invocationOnMock.getArgument(2))
-                        .readEntriesComplete(readEntries, 
PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal);
-            } else {
-                ((PersistentStickyKeyDispatcherMultipleConsumers) 
invocationOnMock.getArgument(2))
-                        .readEntriesComplete(Collections.emptyList(), 
PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal);
-            }
+            int maxEntries = invocationOnMock.getArgument(0);
+            Set<Position> alreadyReceived = new TreeSet<>();
+            alreadyReceived.addAll(actualEntriesToConsumer1);
+            alreadyReceived.addAll(actualEntriesToConsumer2);
+            List<Entry> entries = allEntries.stream()
+                    .filter(entry -> entry.getLedgerId() != -1 && 
!alreadyReceived.contains(entry.getPosition()))
+                    .limit(maxEntries)
+                    .collect(Collectors.toList());
+            AsyncCallbacks.ReadEntriesCallback callback = 
invocationOnMock.getArgument(2);
+            Object ctx = invocationOnMock.getArgument(3);
+            callback.readEntriesComplete(copyEntries(entries), ctx);
             return null;
-        }).when(cursorMock).asyncReadEntriesOrWait(anyInt(), anyLong(),
-                any(PersistentStickyKeyDispatcherMultipleConsumers.class),
-                
eq(PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal), any());
-
-        // (1) Run sendMessagesToConsumers
-        // (2) Attempts to send message1 to consumer1 but skipped because 
availablePermits is 0
-        // (3) Change availablePermits of consumer1 to 10
-        // (4) Run readMoreEntries internally
-        // (5) Run sendMessagesToConsumers internally
-        // (6) Attempts to send message3 to consumer2 but skipped because 
redeliveryMessages contains message2
-        
persistentDispatcher.sendMessagesToConsumers(PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Replay,
-                redeliverEntries, true);
-        while (remainingEntriesNum.get() > 0) {
-            // (7) Run readMoreEntries and resend message1 to consumer1 and 
message2-3 to consumer2
-            persistentDispatcher.readMoreEntries();
-        }
+        }).when(cursorMock).asyncReadEntriesOrWait(anyInt(), anyLong(), any(), 
any(), any());
+
+        // add entries to redeliver
+        redeliverEntries.forEach(entry -> {
+            // calculate hash
+            EntryAndMetadata entryAndMetadata = EntryAndMetadata.create(entry);
+            int stickyKeyHash = 
StickyKeyConsumerSelector.makeStickyKeyHash(entryAndMetadata.getStickyKey());
+            // add to redeliver
+            persistentDispatcher.addMessageToReplay(entry.getLedgerId(), 
entry.getEntryId(), stickyKeyHash);
+        });
+
+        // trigger logic to read entries, includes redelivery logic
+        persistentDispatcher.readMoreEntriesAsync();
+
+        assertTrue(remainingEntriesNum.await(5, TimeUnit.SECONDS));
 
         
assertThat(actualEntriesToConsumer1).containsExactlyElementsOf(expectedEntriesToConsumer1);
         
assertThat(actualEntriesToConsumer2).containsExactlyElementsOf(expectedEntriesToConsumer2);
@@ -535,22 +605,39 @@ public class 
PersistentStickyKeyDispatcherMultipleConsumersTest {
         allEntries.forEach(entry -> entry.release());
     }
 
-    private ByteBuf createMessage(String message, int sequenceId) {
-        return createMessage(message, sequenceId, "testKey");
+    private String generateKeyForConsumer(StickyKeyConsumerSelector selector, 
Consumer consumer) {
+        int i = 0;
+        while (!Thread.currentThread().isInterrupted()) {
+            String key = "key" + i++;
+            Consumer selectedConsumer = selector.select(key.getBytes(UTF_8));
+            if (selectedConsumer == consumer) {
+                return key;
+            }
+        }
+        return null;
     }
 
-    private ByteBuf createMessage(String message, int sequenceId, String key) {
+    private EntryImpl createEntry(long ledgerId, long entryId, String message, 
long sequenceId) {
+        return createEntry(ledgerId, entryId, message, sequenceId, "testKey");
+    }
+
+    private EntryImpl createEntry(long ledgerId, long entryId, String message, 
long sequenceId, String key) {
+        ByteBuf data = createMessage(message, sequenceId, key);
+        EntryImpl entry = EntryImpl.create(ledgerId, entryId, data);
+        data.release();
+        return entry;
+    }
+
+    private ByteBuf createMessage(String message, long sequenceId, String key) 
{
         MessageMetadata messageMetadata = new MessageMetadata()
                 .setSequenceId(sequenceId)
                 .setProducerName("testProducer")
                 .setPartitionKey(key)
                 .setPartitionKeyB64Encoded(false)
                 .setPublishTime(System.currentTimeMillis());
-        return serializeMetadataAndPayload(Commands.ChecksumType.Crc32c, 
messageMetadata, Unpooled.copiedBuffer(message.getBytes(UTF_8)));
-    }
-
-    private int getStickyKeyHash(Entry entry) {
-        byte[] stickyKey = Commands.peekStickyKey(entry.getDataBuffer(), 
topicName, subscriptionName);
-        return StickyKeyConsumerSelector.makeStickyKeyHash(stickyKey);
+        ByteBuf payload = Unpooled.copiedBuffer(message.getBytes(UTF_8));
+        ByteBuf byteBuf = 
serializeMetadataAndPayload(Commands.ChecksumType.Crc32c, messageMetadata, 
payload);
+        payload.release();
+        return byteBuf;
     }
 }
diff --git 
a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/AcknowledgementsGroupingTrackerTest.java
 
b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/AcknowledgementsGroupingTrackerTest.java
index 1d1a6f85bfd..efa398afba6 100644
--- 
a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/AcknowledgementsGroupingTrackerTest.java
+++ 
b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/AcknowledgementsGroupingTrackerTest.java
@@ -18,6 +18,7 @@
  */
 package org.apache.pulsar.client.impl;
 
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
@@ -37,15 +38,16 @@ import java.util.Collections;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import org.apache.pulsar.client.api.MessageId;
 import org.apache.pulsar.client.api.MessageIdAdv;
 import org.apache.pulsar.client.impl.conf.ClientConfigurationData;
 import org.apache.pulsar.client.impl.conf.ConsumerConfigurationData;
 import org.apache.pulsar.client.util.TimedCompletableFuture;
 import org.apache.pulsar.common.api.proto.CommandAck.AckType;
-import org.apache.pulsar.common.util.collections.ConcurrentBitSetRecyclable;
 import org.apache.pulsar.common.util.collections.ConcurrentOpenHashMap;
 import org.apache.pulsar.common.api.proto.ProtocolVersion;
+import org.apache.pulsar.common.util.collections.ConcurrentBitSetRecyclable;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.BeforeClass;
 import org.testng.annotations.DataProvider;
@@ -56,6 +58,7 @@ public class AcknowledgementsGroupingTrackerTest {
     private ClientCnx cnx;
     private ConsumerImpl<?> consumer;
     private EventLoopGroup eventLoopGroup;
+    private AtomicBoolean returnCnx = new AtomicBoolean(true);
 
     @BeforeClass
     public void setup() throws NoSuchFieldException, IllegalAccessException {
@@ -68,12 +71,12 @@ public class AcknowledgementsGroupingTrackerTest {
         ConnectionPool connectionPool = mock(ConnectionPool.class);
         when(client.getCnxPool()).thenReturn(connectionPool);
         doReturn(client).when(consumer).getClient();
-        doReturn(cnx).when(consumer).getClientCnx();
         doReturn(new ConsumerStatsRecorderImpl()).when(consumer).getStats();
         doReturn(new UnAckedMessageTracker().UNACKED_MESSAGE_TRACKER_DISABLED)
                 .when(consumer).getUnAckedMessageTracker();
-        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
-        when(cnx.ctx()).thenReturn(ctx);
+        ChannelHandlerContext ctx = 
ClientTestFixtures.mockChannelHandlerContext();
+        doAnswer(invocation -> returnCnx.get() ? cnx : 
null).when(consumer).getClientCnx();
+        doReturn(ctx).when(cnx).ctx();
     }
 
     @DataProvider(name = "isNeedReceipt")
@@ -131,8 +134,6 @@ public class AcknowledgementsGroupingTrackerTest {
         tracker.addAcknowledgment(msg6, AckType.Individual, 
Collections.emptyMap());
         assertTrue(tracker.isDuplicate(msg6));
 
-        when(consumer.getClientCnx()).thenReturn(cnx);
-
         tracker.flush();
 
         assertTrue(tracker.isDuplicate(msg1));
@@ -191,8 +192,6 @@ public class AcknowledgementsGroupingTrackerTest {
         tracker.addListAcknowledgment(Collections.singletonList(msg6), 
AckType.Individual, Collections.emptyMap());
         assertTrue(tracker.isDuplicate(msg6));
 
-        when(consumer.getClientCnx()).thenReturn(cnx);
-
         tracker.flush();
 
         assertTrue(tracker.isDuplicate(msg1));
@@ -219,12 +218,13 @@ public class AcknowledgementsGroupingTrackerTest {
 
         assertFalse(tracker.isDuplicate(msg1));
 
-        when(consumer.getClientCnx()).thenReturn(null);
-
-        tracker.addAcknowledgment(msg1, AckType.Individual, 
Collections.emptyMap());
-        assertFalse(tracker.isDuplicate(msg1));
-
-        when(consumer.getClientCnx()).thenReturn(cnx);
+        returnCnx.set(false);
+        try {
+            tracker.addAcknowledgment(msg1, AckType.Individual, 
Collections.emptyMap());
+            assertFalse(tracker.isDuplicate(msg1));
+        } finally {
+            returnCnx.set(true);
+        }
 
         tracker.flush();
         assertFalse(tracker.isDuplicate(msg1));
@@ -248,12 +248,13 @@ public class AcknowledgementsGroupingTrackerTest {
 
         assertFalse(tracker.isDuplicate(msg1));
 
-        when(consumer.getClientCnx()).thenReturn(null);
-
-        tracker.addListAcknowledgment(Collections.singletonList(msg1), 
AckType.Individual, Collections.emptyMap());
-        assertTrue(tracker.isDuplicate(msg1));
-
-        when(consumer.getClientCnx()).thenReturn(cnx);
+        returnCnx.set(false);
+        try {
+            tracker.addListAcknowledgment(Collections.singletonList(msg1), 
AckType.Individual, Collections.emptyMap());
+            assertTrue(tracker.isDuplicate(msg1));
+        } finally {
+            returnCnx.set(true);
+        }
 
         tracker.flush();
         assertFalse(tracker.isDuplicate(msg1));
@@ -313,8 +314,6 @@ public class AcknowledgementsGroupingTrackerTest {
         tracker.addAcknowledgment(msg6, AckType.Individual, 
Collections.emptyMap());
         assertTrue(tracker.isDuplicate(msg6));
 
-        when(consumer.getClientCnx()).thenReturn(cnx);
-
         tracker.flush();
 
         assertTrue(tracker.isDuplicate(msg1));
@@ -375,8 +374,6 @@ public class AcknowledgementsGroupingTrackerTest {
         tracker.addListAcknowledgment(Collections.singletonList(msg6), 
AckType.Individual, Collections.emptyMap());
         assertTrue(tracker.isDuplicate(msg6));
 
-        when(consumer.getClientCnx()).thenReturn(cnx);
-
         tracker.flush();
 
         assertTrue(tracker.isDuplicate(msg1));
diff --git 
a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/BinaryProtoLookupServiceTest.java
 
b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/BinaryProtoLookupServiceTest.java
index 984b201d1ce..d8aa5e5cd08 100644
--- 
a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/BinaryProtoLookupServiceTest.java
+++ 
b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/BinaryProtoLookupServiceTest.java
@@ -20,6 +20,7 @@ package org.apache.pulsar.client.impl;
 
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.anyLong;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
@@ -39,6 +40,7 @@ import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pulsar.client.api.PulsarClientException.LookupException;
@@ -73,8 +75,17 @@ public class BinaryProtoLookupServiceTest {
         CompletableFuture<LookupDataResult> lookupFuture2 = 
CompletableFuture.completedFuture(lookupResult2);
 
         ClientCnx clientCnx = mock(ClientCnx.class);
-        when(clientCnx.newLookup(any(ByteBuf.class), 
anyLong())).thenReturn(lookupFuture1, lookupFuture1,
-                lookupFuture2);
+        AtomicInteger lookupInvocationCounter = new AtomicInteger();
+        doAnswer(invocation -> {
+            ByteBuf byteBuf = invocation.getArgument(0);
+            byteBuf.release();
+            int lookupInvocationCount = 
lookupInvocationCounter.incrementAndGet();
+            if (lookupInvocationCount < 3) {
+                return lookupFuture1;
+            } else {
+                return lookupFuture2;
+            }
+        }).when(clientCnx).newLookup(any(ByteBuf.class), anyLong());
 
         CompletableFuture<ClientCnx> connectionFuture = 
CompletableFuture.completedFuture(clientCnx);
 
diff --git 
a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientCnxTest.java 
b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientCnxTest.java
index 22220805814..d5fbfd22321 100644
--- 
a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientCnxTest.java
+++ 
b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientCnxTest.java
@@ -21,14 +21,11 @@ package org.apache.pulsar.client.impl;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertFalse;
 import static org.testng.Assert.assertTrue;
 import static org.testng.Assert.fail;
 import io.netty.buffer.ByteBuf;
-import io.netty.channel.Channel;
-import io.netty.channel.ChannelFuture;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.EventLoopGroup;
 import io.netty.util.concurrent.DefaultThreadFactory;
@@ -63,14 +60,7 @@ public class ClientCnxTest {
         conf.setOperationTimeoutMs(10);
         conf.setKeepAliveIntervalSeconds(0);
         ClientCnx cnx = new ClientCnx(conf, eventLoop);
-
-        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
-        Channel channel = mock(Channel.class);
-        when(ctx.channel()).thenReturn(channel);
-        ChannelFuture listenerFuture = mock(ChannelFuture.class);
-        when(listenerFuture.addListener(any())).thenReturn(listenerFuture);
-        when(ctx.writeAndFlush(any())).thenReturn(listenerFuture);
-
+        ChannelHandlerContext ctx = 
ClientTestFixtures.mockChannelHandlerContext();
         cnx.channelActive(ctx);
 
         try {
@@ -89,13 +79,7 @@ public class ClientCnxTest {
         conf.setOperationTimeoutMs(10_000);
         conf.setKeepAliveIntervalSeconds(0);
         ClientCnx cnx = new ClientCnx(conf, eventLoop);
-
-        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
-        Channel channel = mock(Channel.class);
-        when(ctx.channel()).thenReturn(channel);
-        ChannelFuture listenerFuture = mock(ChannelFuture.class);
-        when(listenerFuture.addListener(any())).thenReturn(listenerFuture);
-        when(ctx.writeAndFlush(any())).thenReturn(listenerFuture);
+        ChannelHandlerContext ctx = 
ClientTestFixtures.mockChannelHandlerContext();
         cnx.channelActive(ctx);
         CountDownLatch countDownLatch = new CountDownLatch(1);
         CompletableFuture<Exception> completableFuture = new 
CompletableFuture<>();
@@ -127,13 +111,7 @@ public class ClientCnxTest {
         conf.setOperationTimeoutMs(10_000);
         conf.setKeepAliveIntervalSeconds(0);
         ClientCnx cnx = new ClientCnx(conf, eventLoop);
-
-        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
-        Channel channel = mock(Channel.class);
-        when(ctx.channel()).thenReturn(channel);
-        ChannelFuture listenerFuture = mock(ChannelFuture.class);
-        when(listenerFuture.addListener(any())).thenReturn(listenerFuture);
-        when(ctx.writeAndFlush(any())).thenReturn(listenerFuture);
+        ChannelHandlerContext ctx = 
ClientTestFixtures.mockChannelHandlerContext();
         cnx.channelActive(ctx);
         cnx.state = ClientCnx.State.Ready;
         CountDownLatch countDownLatch = new CountDownLatch(1);
@@ -170,13 +148,7 @@ public class ClientCnxTest {
         conf.setOperationTimeoutMs(10_000);
         conf.setKeepAliveIntervalSeconds(0);
         ClientCnx cnx = new ClientCnx(conf, eventLoop);
-
-        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
-        Channel channel = mock(Channel.class);
-        when(ctx.channel()).thenReturn(channel);
-        ChannelFuture listenerFuture = mock(ChannelFuture.class);
-        when(listenerFuture.addListener(any())).thenReturn(listenerFuture);
-        when(ctx.writeAndFlush(any())).thenReturn(listenerFuture);
+        ChannelHandlerContext ctx = 
ClientTestFixtures.mockChannelHandlerContext();
         cnx.channelActive(ctx);
         for (int i = 0; i < 5001; i++) {
             cnx.newLookup(null, i);
@@ -197,9 +169,7 @@ public class ClientCnxTest {
         conf.setOperationTimeoutMs(10);
         ClientCnx cnx = new ClientCnx(conf, eventLoop);
 
-        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
-        Channel channel = mock(Channel.class);
-        when(ctx.channel()).thenReturn(channel);
+        ChannelHandlerContext ctx = 
ClientTestFixtures.mockChannelHandlerContext();
 
         Field ctxField = PulsarHandler.class.getDeclaredField("ctx");
         ctxField.setAccessible(true);
@@ -231,9 +201,7 @@ public class ClientCnxTest {
         ClientConfigurationData conf = new ClientConfigurationData();
         ClientCnx cnx = new ClientCnx(conf, eventLoop);
 
-        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
-        Channel channel = mock(Channel.class);
-        when(ctx.channel()).thenReturn(channel);
+        ChannelHandlerContext ctx = 
ClientTestFixtures.mockChannelHandlerContext();
 
         Field ctxField = PulsarHandler.class.getDeclaredField("ctx");
         ctxField.setAccessible(true);
@@ -246,10 +214,6 @@ public class ClientCnxTest {
         cnxField.setAccessible(true);
         cnxField.set(cnx, ClientCnx.State.SentConnectFrame);
 
-        ChannelFuture listenerFuture = mock(ChannelFuture.class);
-        when(listenerFuture.addListener(any())).thenReturn(listenerFuture);
-        when(ctx.writeAndFlush(any())).thenReturn(listenerFuture);
-
         ByteBuf getLastIdCmd = Commands.newGetLastMessageId(5, requestId);
         CompletableFuture<?> future = cnx.sendGetLastMessageId(getLastIdCmd, 
requestId);
 
@@ -382,13 +346,7 @@ public class ClientCnxTest {
             ClientConfigurationData conf = new ClientConfigurationData();
             ClientCnx cnx = new ClientCnx(conf, eventLoop);
 
-            ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
-            Channel channel = mock(Channel.class);
-            when(ctx.channel()).thenReturn(channel);
-
-            ChannelFuture listenerFuture = mock(ChannelFuture.class);
-            when(listenerFuture.addListener(any())).thenReturn(listenerFuture);
-            when(ctx.writeAndFlush(any())).thenReturn(listenerFuture);
+            ChannelHandlerContext ctx = 
ClientTestFixtures.mockChannelHandlerContext();
 
             Field ctxField = PulsarHandler.class.getDeclaredField("ctx");
             ctxField.setAccessible(true);
diff --git 
a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientTestFixtures.java
 
b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientTestFixtures.java
index ff7d7f12dd4..bac69a7fbd7 100644
--- 
a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientTestFixtures.java
+++ 
b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientTestFixtures.java
@@ -22,11 +22,21 @@ import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.EventLoop;
+import io.netty.util.ReferenceCountUtil;
 import io.netty.util.Timer;
+import io.netty.util.concurrent.EventExecutor;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.GenericFutureListener;
+import io.netty.util.concurrent.SucceededFuture;
 import java.net.SocketAddress;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
@@ -68,11 +78,7 @@ class ClientTestFixtures {
     }
 
     static PulsarClientImpl mockClientCnx(PulsarClientImpl clientMock) {
-        ClientCnx clientCnxMock = mock(ClientCnx.class, 
Mockito.RETURNS_DEEP_STUBS);
-        
when(clientCnxMock.ctx()).thenReturn(mock(ChannelHandlerContext.class));
-        when(clientCnxMock.sendRequestWithId(any(), anyLong()))
-                
.thenReturn(CompletableFuture.completedFuture(mock(ProducerResponse.class)));
-        
when(clientCnxMock.channel().remoteAddress()).thenReturn(mock(SocketAddress.class));
+        ClientCnx clientCnxMock = mockClientCnx();
         
when(clientMock.getConnection(any())).thenReturn(CompletableFuture.completedFuture(clientCnxMock));
         
when(clientMock.getConnection(anyString())).thenReturn(CompletableFuture.completedFuture(clientCnxMock));
         when(clientMock.getConnection(anyString(), anyInt()))
@@ -87,6 +93,82 @@ class ClientTestFixtures {
         return clientMock;
     }
 
+    public static ClientCnx mockClientCnx() {
+        ClientCnx clientCnxMock = mock(ClientCnx.class, 
Mockito.RETURNS_DEEP_STUBS);
+        ChannelHandlerContext ctx = mockChannelHandlerContext();
+        doReturn(ctx).when(clientCnxMock).ctx();
+        doAnswer(invocation -> {
+            ByteBuf buf = invocation.getArgument(0);
+            buf.release();
+            return 
CompletableFuture.completedFuture(mock(ProducerResponse.class));
+        }).when(clientCnxMock).sendRequestWithId(any(), anyLong());
+        
when(clientCnxMock.channel().remoteAddress()).thenReturn(mock(SocketAddress.class));
+        return clientCnxMock;
+    }
+
+    /**
+     * Mock a ChannelHandlerContext where write and writeAndFlush are always 
successful.
+     * This might not be suitable for all tests.
+     *
+     * @return a mocked ChannelHandlerContext
+     */
+    public static ChannelHandlerContext mockChannelHandlerContext() {
+        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+
+        // return an empty channel mock from ctx.channel()
+        Channel channel = mock(Channel.class);
+        when(ctx.channel()).thenReturn(channel);
+
+        // handle write and writeAndFlush methods so that the input message is 
released
+
+        // create a listener future that is returned from write and 
writeAndFlush
+        // that immediately completes the listener in the calling thread as if 
it was successful
+        ChannelFuture listenerFuture = mock(ChannelFuture.class);
+        Future<Void> succeededFuture = createSucceededFuture();
+        doAnswer(invocation -> {
+            GenericFutureListener<Future<Void>> listener = 
invocation.getArgument(0);
+            listener.operationComplete(succeededFuture);
+            return listenerFuture;
+        }).when(listenerFuture).addListener(any());
+
+        // handle write and writeAndFlush methods so that the input message is 
released
+        doAnswer(invocation -> {
+            Object msg = invocation.getArgument(0);
+            ReferenceCountUtil.release(msg);
+            return listenerFuture;
+        }).when(ctx).write(any(), any());
+        doAnswer(invocation -> {
+            Object msg = invocation.getArgument(0);
+            ReferenceCountUtil.release(msg);
+            return listenerFuture;
+        }).when(ctx).writeAndFlush(any(), any());
+        doAnswer(invocation -> {
+            Object msg = invocation.getArgument(0);
+            ReferenceCountUtil.release(msg);
+            return listenerFuture;
+        }).when(ctx).writeAndFlush(any());
+
+        return ctx;
+    }
+
+    public static Future<Void> createSucceededFuture() {
+        EventExecutor eventExecutor = mockEventExecutor();
+        // create a succeeded future that is returned from the listener, 
listeners will run in the calling thread
+        // using the mocked EventExecutor
+        SucceededFuture<Void> succeededFuture = new 
SucceededFuture<>(eventExecutor, null);
+        return succeededFuture;
+    }
+
+    public static EventExecutor mockEventExecutor() {
+        // mock an EventExecutor that runs the listener in the calling thread
+        EventExecutor eventExecutor = mock(EventExecutor.class);
+        doAnswer(invocation -> {
+            invocation.getArgument(0, Runnable.class).run();
+            return null;
+        }).when(eventExecutor).execute(any(Runnable.class));
+        return eventExecutor;
+    }
+
     static <T> CompletableFuture<T> createDelayedCompletedFuture(T result, int 
delayMillis) {
         CompletableFuture<T> future = new CompletableFuture<>();
         SCHEDULER.schedule(() -> future.complete(result), delayMillis, 
TimeUnit.MILLISECONDS);

Reply via email to