This is an automated email from the ASF dual-hosted git repository. lhotari pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/pulsar.git
commit 03de43e0b9405578fc07857fdf0db2bcf744082e Author: Lari Hotari <[email protected]> AuthorDate: Wed May 14 10:00:26 2025 +0300 [fix][test] Fix more Netty ByteBuf leaks in tests (#24299) (cherry picked from commit 174245d322574ef422f08a4c249db4e649010c83) --- .../pulsar/broker/service/EntryAndMetadata.java | 2 +- ...ntStickyKeyDispatcherMultipleConsumersTest.java | 495 ++++++++++++--------- .../impl/AcknowledgementsGroupingTrackerTest.java | 45 +- .../client/impl/BinaryProtoLookupServiceTest.java | 15 +- .../apache/pulsar/client/impl/ClientCnxTest.java | 56 +-- .../pulsar/client/impl/ClientTestFixtures.java | 92 +++- 6 files changed, 420 insertions(+), 285 deletions(-) diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/EntryAndMetadata.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/EntryAndMetadata.java index 70643d5de2a..d9cdedf6477 100644 --- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/EntryAndMetadata.java +++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/EntryAndMetadata.java @@ -45,7 +45,7 @@ public class EntryAndMetadata implements Entry { } @VisibleForTesting - static EntryAndMetadata create(final Entry entry) { + public static EntryAndMetadata create(final Entry entry) { return create(entry, Commands.peekAndCopyMessageMetadata(entry.getDataBuffer(), "", -1)); } diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/PersistentStickyKeyDispatcherMultipleConsumersTest.java b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/PersistentStickyKeyDispatcherMultipleConsumersTest.java index bf3a79133b0..9b7c98cc30e 100644 --- a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/PersistentStickyKeyDispatcherMultipleConsumersTest.java +++ b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/PersistentStickyKeyDispatcherMultipleConsumersTest.java @@ -27,10 +27,8 @@ import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyList; import static org.mockito.Mockito.anyLong; import static org.mockito.Mockito.anySet; -import static org.mockito.Mockito.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -40,21 +38,27 @@ import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoopGroup; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.SucceededFuture; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Queue; +import java.util.Objects; import java.util.Set; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.TreeSet; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import org.apache.bookkeeper.common.util.OrderedExecutor; +import org.apache.bookkeeper.mledger.AsyncCallbacks; import org.apache.bookkeeper.mledger.Entry; import org.apache.bookkeeper.mledger.Position; import org.apache.bookkeeper.mledger.impl.EntryImpl; @@ -64,10 +68,13 @@ import org.apache.pulsar.broker.PulsarService; import org.apache.pulsar.broker.ServiceConfiguration; import org.apache.pulsar.broker.service.BrokerService; import org.apache.pulsar.broker.service.Consumer; +import org.apache.pulsar.broker.service.EntryAndMetadata; import org.apache.pulsar.broker.service.EntryBatchIndexesAcks; import org.apache.pulsar.broker.service.EntryBatchSizes; import org.apache.pulsar.broker.service.RedeliveryTracker; import org.apache.pulsar.broker.service.StickyKeyConsumerSelector; +import org.apache.pulsar.broker.service.TransportCnx; +import org.apache.pulsar.broker.service.plugin.EntryFilterProvider; import org.apache.pulsar.common.api.proto.KeySharedMeta; import org.apache.pulsar.common.api.proto.KeySharedMode; import org.apache.pulsar.common.api.proto.MessageMetadata; @@ -77,6 +84,7 @@ import org.apache.pulsar.common.protocol.Markers; import org.awaitility.Awaitility; import org.mockito.ArgumentCaptor; import org.testng.Assert; +import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -90,13 +98,14 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest { private PersistentTopic topicMock; private PersistentSubscription subscriptionMock; private ServiceConfiguration configMock; - private ChannelPromise channelMock; + private Future<Void> succeededFuture; private OrderedExecutor orderedExecutor; private PersistentStickyKeyDispatcherMultipleConsumers persistentDispatcher; final String topicName = "persistent://public/default/testTopic"; final String subscriptionName = "testSubscription"; + private AtomicInteger consumerMockAvailablePermits; @BeforeMethod public void setup() throws Exception { @@ -106,12 +115,17 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest { doReturn(true).when(configMock).isSubscriptionKeySharedUseConsistentHashing(); doReturn(1).when(configMock).getSubscriptionKeySharedConsistentHashingReplicaPoints(); doReturn(true).when(configMock).isDispatcherDispatchMessagesInSubscriptionThread(); + doReturn(false).when(configMock).isAllowOverrideEntryFilters(); pulsarMock = mock(PulsarService.class); doReturn(configMock).when(pulsarMock).getConfiguration(); + EntryFilterProvider mockEntryFilterProvider = mock(EntryFilterProvider.class); + when(mockEntryFilterProvider.getBrokerEntryFilters()).thenReturn(Collections.emptyList()); + brokerMock = mock(BrokerService.class); doReturn(pulsarMock).when(brokerMock).pulsar(); + when(brokerMock.getEntryFilterProvider()).thenReturn(mockEntryFilterProvider); HierarchyTopicPolicies topicPolicies = new HierarchyTopicPolicies(); topicPolicies.getMaxConsumersPerSubscription().updateBrokerValue(0); @@ -122,7 +136,7 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest { EventLoopGroup eventLoopGroup = mock(EventLoopGroup.class); doReturn(eventLoopGroup).when(brokerMock).executor(); doAnswer(invocation -> { - orderedExecutor.execute(((Runnable)invocation.getArguments()[0])); + orderedExecutor.execute(invocation.getArgument(0, Runnable.class)); return null; }).when(eventLoopGroup).execute(any(Runnable.class)); @@ -135,12 +149,36 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest { doReturn(null).when(cursorMock).getLastIndividualDeletedRange(); doReturn(subscriptionName).when(cursorMock).getName(); - consumerMock = mock(Consumer.class); - channelMock = mock(ChannelPromise.class); + consumerMock = createMockConsumer(); + EventExecutor eventExecutor = mock(EventExecutor.class); + doAnswer(invocation -> { + orderedExecutor.execute(invocation.getArgument(0, Runnable.class)); + return null; + }).when(eventExecutor).execute(any(Runnable.class)); + doReturn(false).when(eventExecutor).inEventLoop(); + succeededFuture = new SucceededFuture<>(eventExecutor, null); doReturn("consumer1").when(consumerMock).consumerName(); - doReturn(1000).when(consumerMock).getAvailablePermits(); + consumerMockAvailablePermits = new AtomicInteger(1000); + doAnswer(invocation -> consumerMockAvailablePermits.get()).when(consumerMock).getAvailablePermits(); doReturn(true).when(consumerMock).isWritable(); - doReturn(channelMock).when(consumerMock).sendMessages( + mockSendMessages(consumerMock, null); + + subscriptionMock = mock(PersistentSubscription.class); + when(subscriptionMock.getTopic()).thenReturn(topicMock); + persistentDispatcher = new PersistentStickyKeyDispatcherMultipleConsumers( + topicMock, cursorMock, subscriptionMock, configMock, + new KeySharedMeta().setKeySharedMode(KeySharedMode.AUTO_SPLIT)); + } + + private void mockSendMessages(Consumer consumerMock, java.util.function.Consumer<List<Entry>> entryConsumer) { + doAnswer(invocation -> { + List<Entry> entries = invocation.getArgument(0); + if (entryConsumer != null) { + entryConsumer.accept(entries); + } + entries.stream().filter(Objects::nonNull).forEach(Entry::release); + return succeededFuture; + }).when(consumerMock).sendMessages( anyList(), any(EntryBatchSizes.class), any(EntryBatchIndexesAcks.class), @@ -149,16 +187,25 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest { anyLong(), any(RedeliveryTracker.class) ); + } - subscriptionMock = mock(PersistentSubscription.class); - persistentDispatcher = new PersistentStickyKeyDispatcherMultipleConsumers( - topicMock, cursorMock, subscriptionMock, configMock, - new KeySharedMeta().setKeySharedMode(KeySharedMode.AUTO_SPLIT)); + protected static Consumer createMockConsumer() { + Consumer consumerMock = mock(Consumer.class); + TransportCnx transportCnx = mock(TransportCnx.class); + doReturn(transportCnx).when(consumerMock).cnx(); + doReturn(true).when(transportCnx).isActive(); + doReturn(100).when(consumerMock).getMaxUnackedMessages(); + doReturn(1).when(consumerMock).getAvgMessagesPerEntry(); + return consumerMock; } + @AfterMethod(alwaysRun = true) public void cleanup() { + if (persistentDispatcher != null && !persistentDispatcher.isClosed()) { + persistentDispatcher.close(); + } if (orderedExecutor != null) { - orderedExecutor.shutdown(); + orderedExecutor.shutdownNow(); orderedExecutor = null; } } @@ -166,8 +213,8 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest { @Test(timeOut = 10000) public void testAddConsumerWhenClosed() throws Exception { persistentDispatcher.close().get(); - Consumer consumer = mock(Consumer.class); - persistentDispatcher.addConsumer(consumer); + Consumer consumer = createMockConsumer(); + persistentDispatcher.addConsumer(consumer).join(); verify(consumer, times(1)).disconnect(); assertEquals(0, persistentDispatcher.getConsumers().size()); assertTrue(persistentDispatcher.getSelector().getConsumerKeyHashRanges().isEmpty()); @@ -180,19 +227,19 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest { topicMock, cursorMock, subscriptionMock, configMock, new KeySharedMeta().setKeySharedMode(KeySharedMode.AUTO_SPLIT)); - Consumer consumer0 = mock(Consumer.class); + Consumer consumer0 = createMockConsumer(); when(consumer0.consumerName()).thenReturn("c0-1"); - Consumer consumer1 = mock(Consumer.class); + Consumer consumer1 = createMockConsumer(); when(consumer1.consumerName()).thenReturn("c1"); - Consumer consumer2 = mock(Consumer.class); + Consumer consumer2 = createMockConsumer(); when(consumer2.consumerName()).thenReturn("c2"); - Consumer consumer3 = mock(Consumer.class); + Consumer consumer3 = createMockConsumer(); when(consumer3.consumerName()).thenReturn("c3"); - Consumer consumer4 = mock(Consumer.class); + Consumer consumer4 = createMockConsumer(); when(consumer4.consumerName()).thenReturn("c4"); - Consumer consumer5 = mock(Consumer.class); + Consumer consumer5 = createMockConsumer(); when(consumer5.consumerName()).thenReturn("c5"); - Consumer consumer6 = mock(Consumer.class); + Consumer consumer6 = createMockConsumer(); when(consumer6.consumerName()).thenReturn("c6"); when(cursorMock.getNumberOfEntriesSinceFirstNotAckedMessage()).thenReturn(100L); @@ -255,7 +302,7 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest { @Test public void testSendMarkerMessage() { try { - persistentDispatcher.addConsumer(consumerMock); + persistentDispatcher.addConsumer(consumerMock).join(); persistentDispatcher.consumerFlow(consumerMock, 1000); } catch (Exception e) { fail("Failed to add mock consumer", e); @@ -264,14 +311,16 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest { List<Entry> entries = new ArrayList<>(); ByteBuf markerMessage = Markers.newReplicatedSubscriptionsSnapshotRequest("testSnapshotId", "testSourceCluster"); entries.add(EntryImpl.create(1, 1, markerMessage)); - entries.add(EntryImpl.create(1, 2, createMessage("message1", 1))); - entries.add(EntryImpl.create(1, 3, createMessage("message2", 2))); - entries.add(EntryImpl.create(1, 4, createMessage("message3", 3))); - entries.add(EntryImpl.create(1, 5, createMessage("message4", 4))); - entries.add(EntryImpl.create(1, 6, createMessage("message5", 5))); + markerMessage.release(); + entries.add(createEntry(1, 2, "message1", 1)); + entries.add(createEntry(1, 3, "message2", 2)); + entries.add(createEntry(1, 4, "message3", 3)); + entries.add(createEntry(1, 5, "message4", 4)); + entries.add(createEntry(1, 6, "message5", 5)); try { - persistentDispatcher.readEntriesComplete(entries, PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal); + persistentDispatcher.readEntriesComplete(copyEntries(entries), + PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal); } catch (Exception e) { fail("Failed to readEntriesComplete.", e); } @@ -291,243 +340,264 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest { List<Integer> allTotalMessagesCaptor = totalMessagesCaptor.getAllValues(); Assert.assertEquals(allTotalMessagesCaptor.get(0).intValue(), 5); }); + + entries.forEach(Entry::release); + } + + private static List<Entry> copyEntries(List<Entry> entries) { + return entries.stream().map(entry -> EntryImpl.create((EntryImpl) entry)) + .collect(Collectors.toList()); } @Test(timeOut = 10000) public void testSendMessage() { KeySharedMeta keySharedMeta = new KeySharedMeta().setKeySharedMode(KeySharedMode.STICKY); - PersistentStickyKeyDispatcherMultipleConsumers persistentDispatcher = new PersistentStickyKeyDispatcherMultipleConsumers( + PersistentStickyKeyDispatcherMultipleConsumers + persistentDispatcher = new PersistentStickyKeyDispatcherMultipleConsumers( topicMock, cursorMock, subscriptionMock, configMock, keySharedMeta); try { keySharedMeta.addHashRange() .setStart(0) .setEnd(9); - Consumer consumerMock = mock(Consumer.class); + Consumer consumerMock = createMockConsumer(); doReturn(keySharedMeta).when(consumerMock).getKeySharedMeta(); - persistentDispatcher.addConsumer(consumerMock); + mockSendMessages(consumerMock, null); + persistentDispatcher.addConsumer(consumerMock).join(); persistentDispatcher.consumerFlow(consumerMock, 1000); } catch (Exception e) { fail("Failed to add mock consumer", e); } List<Entry> entries = new ArrayList<>(); - entries.add(EntryImpl.create(1, 1, createMessage("message1", 1))); - entries.add(EntryImpl.create(1, 2, createMessage("message2", 2))); + entries.add(createEntry(1, 1, "message1", 1)); + entries.add(createEntry(1, 2, "message2", 2)); try { //Should success,see issue #8960 - persistentDispatcher.readEntriesComplete(entries, PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal); + persistentDispatcher.readEntriesComplete(copyEntries(entries), + PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal); } catch (Exception e) { fail("Failed to readEntriesComplete.", e); } + + entries.forEach(Entry::release); } @Test - public void testSkipRedeliverTemporally() { - final Consumer slowConsumerMock = mock(Consumer.class); - final ChannelPromise slowChannelMock = mock(ChannelPromise.class); - // add entries to redeliver and read target + public void testSkipRedeliverTemporally() throws InterruptedException { + // add first consumer + persistentDispatcher.addConsumer(consumerMock).join(); + // add slow consumer + final Consumer slowConsumerMock = createMockConsumer(); + doReturn("consumer2").when(slowConsumerMock).consumerName(); + AtomicInteger slowConsumerAvailablePermits = new AtomicInteger(0); + doAnswer(invocation -> { + return slowConsumerAvailablePermits.get(); + }).when(slowConsumerMock).getAvailablePermits(); + persistentDispatcher.addConsumer(slowConsumerMock).join(); + + StickyKeyConsumerSelector selector = persistentDispatcher.getSelector(); + String keyForConsumer = generateKeyForConsumer(selector, consumerMock); + String keyForSlowConsumer = generateKeyForConsumer(selector, slowConsumerMock); + + Set<Position> alreadySent = new ConcurrentSkipListSet<>(); + + final List<Entry> allEntries = new ArrayList<>(); + allEntries.add(createEntry(1, 1, "message1", 1, keyForSlowConsumer)); + allEntries.add(createEntry(1, 2, "message2", 2, keyForSlowConsumer)); + allEntries.add(createEntry(1, 3, "message3", 3, keyForConsumer)); + + // add first entry to redeliver initially final List<Entry> redeliverEntries = new ArrayList<>(); - redeliverEntries.add(EntryImpl.create(1, 1, createMessage("message1", 1, "key1"))); - final List<Entry> readEntries = new ArrayList<>(); - readEntries.add(EntryImpl.create(1, 2, createMessage("message2", 2, "key1"))); - readEntries.add(EntryImpl.create(1, 3, createMessage("message3", 3, "key2"))); + redeliverEntries.add(allEntries.get(0)); try { - Field totalAvailablePermitsField = PersistentDispatcherMultipleConsumers.class.getDeclaredField("totalAvailablePermits"); + Field totalAvailablePermitsField = + PersistentDispatcherMultipleConsumers.class.getDeclaredField("totalAvailablePermits"); totalAvailablePermitsField.setAccessible(true); totalAvailablePermitsField.set(persistentDispatcher, 1000); - - doAnswer(invocationOnMock -> { - ((PersistentStickyKeyDispatcherMultipleConsumers) invocationOnMock.getArgument(2)) - .readEntriesComplete(readEntries, PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal); - return null; - }).when(cursorMock).asyncReadEntriesOrWait( - anyInt(), anyLong(), any(PersistentStickyKeyDispatcherMultipleConsumers.class), - eq(PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal), any()); } catch (Exception e) { fail("Failed to set to field", e); } - // Create 2Consumers - try { - doReturn("consumer2").when(slowConsumerMock).consumerName(); - // Change slowConsumer availablePermits to 0 and back to normal - when(slowConsumerMock.getAvailablePermits()) - .thenReturn(0) - .thenReturn(1); - doReturn(true).when(slowConsumerMock).isWritable(); - doReturn(slowChannelMock).when(slowConsumerMock).sendMessages( - anyList(), - any(EntryBatchSizes.class), - any(EntryBatchIndexesAcks.class), - anyInt(), - anyLong(), - anyLong(), - any(RedeliveryTracker.class) - ); - - persistentDispatcher.addConsumer(consumerMock); - persistentDispatcher.addConsumer(slowConsumerMock); - } catch (Exception e) { - fail("Failed to add mock consumer", e); - } + // Mock Cursor#asyncReplayEntries + doAnswer(invocationOnMock -> { + Set<Position> positionsArg = invocationOnMock.getArgument(0); + Set<Position> positions = new TreeSet<>(positionsArg); + List<Entry> entries = allEntries.stream() + .filter(entry -> entry.getLedgerId() != -1 && positions.contains(entry.getPosition())) + .toList(); + AsyncCallbacks.ReadEntriesCallback callback = invocationOnMock.getArgument(1); + Object ctx = invocationOnMock.getArgument(2); + callback.readEntriesComplete(copyEntries(entries), ctx); + return Collections.emptySet(); + }).when(cursorMock).asyncReplayEntries(anySet(), any(), any(), anyBoolean()); - // run PersistentStickyKeyDispatcherMultipleConsumers#sendMessagesToConsumers - // run readMoreEntries internally (and skip internally) - // Change slowConsumer availablePermits to 1 - // run PersistentStickyKeyDispatcherMultipleConsumers#sendMessagesToConsumers internally - // and then stop to dispatch to slowConsumer - if (persistentDispatcher.sendMessagesToConsumers(PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal, - redeliverEntries, true)) { - persistentDispatcher.readMoreEntriesAsync(); - } + doAnswer(invocationOnMock -> { + int maxEntries = invocationOnMock.getArgument(0); + AsyncCallbacks.ReadEntriesCallback callback = invocationOnMock.getArgument(2); + List<Entry> entries = allEntries.stream() + .filter(entry -> entry.getLedgerId() != -1 && !alreadySent.contains(entry.getPosition())) + .limit(maxEntries) + .toList(); + Object ctx = invocationOnMock.getArgument(3); + callback.readEntriesComplete(copyEntries(entries), ctx); + return null; + }).when(cursorMock).asyncReadEntriesOrWait(anyInt(), anyLong(), any(), any(), any()); + + doReturn(true).when(slowConsumerMock).isWritable(); + CountDownLatch message3Sent = new CountDownLatch(1); + mockSendMessages(consumerMock, entries -> { + entries.forEach(entry -> { + alreadySent.add(entry.getPosition()); + }); + boolean message3Found = entries.stream() + .anyMatch(entry -> entry.getLedgerId() == 1 && entry.getEntryId() == 3); + if (message3Found) { + message3Sent.countDown(); + } + }); + CountDownLatch slowConsumerMessagesSent = new CountDownLatch(2); + mockSendMessages(slowConsumerMock, entries -> { + entries.forEach(entry -> { + alreadySent.add(entry.getPosition()); + slowConsumerMessagesSent.countDown(); + }); + }); - Awaitility.await().untilAsserted(() -> { - verify(consumerMock, times(1)).sendMessages( - argThat(arg -> { - assertEquals(arg.size(), 1); - Entry entry = arg.get(0); - assertEquals(entry.getLedgerId(), 1); - assertEquals(entry.getEntryId(), 3); - return true; - }), - any(EntryBatchSizes.class), - any(EntryBatchIndexesAcks.class), - anyInt(), - anyLong(), - anyLong(), - any(RedeliveryTracker.class) - ); + // add entries to redeliver + redeliverEntries.forEach(entry -> { + // calculate hash + EntryAndMetadata entryAndMetadata = EntryAndMetadata.create(entry); + int stickyKeyHash = StickyKeyConsumerSelector.makeStickyKeyHash(entryAndMetadata.getStickyKey()); + // add to redeliver + persistentDispatcher.addMessageToReplay(entry.getLedgerId(), entry.getEntryId(), stickyKeyHash); }); - verify(slowConsumerMock, times(0)).sendMessages( - anyList(), - any(EntryBatchSizes.class), - any(EntryBatchIndexesAcks.class), - anyInt(), - anyLong(), - anyLong(), - any(RedeliveryTracker.class) - ); + + // trigger readMoreEntries, will handle redelivery logic and skip slow consumer + persistentDispatcher.readMoreEntriesAsync(); + + assertTrue(message3Sent.await(5, TimeUnit.SECONDS)); + + // verify that slow consumer messages are not sent before message3 to "consumer" + assertEquals(slowConsumerMessagesSent.getCount(), 2); + + // set permits to 2 + slowConsumerAvailablePermits.set(2); + + // now wait for slow consumer messages since there are permits + assertTrue(slowConsumerMessagesSent.await(5, TimeUnit.SECONDS)); + + allEntries.forEach(Entry::release); } @Test(timeOut = 30000) public void testMessageRedelivery() throws Exception { - final Queue<Position> actualEntriesToConsumer1 = new ConcurrentLinkedQueue<>(); - final Queue<Position> actualEntriesToConsumer2 = new ConcurrentLinkedQueue<>(); + final List<Position> actualEntriesToConsumer1 = new CopyOnWriteArrayList<>(); + final List<Position> actualEntriesToConsumer2 = new CopyOnWriteArrayList<>(); - final Queue<Position> expectedEntriesToConsumer1 = new ConcurrentLinkedQueue<>(); - expectedEntriesToConsumer1.add(PositionImpl.get(1, 1)); - final Queue<Position> expectedEntriesToConsumer2 = new ConcurrentLinkedQueue<>(); - expectedEntriesToConsumer2.add(PositionImpl.get(1, 2)); - expectedEntriesToConsumer2.add(PositionImpl.get(1, 3)); + final List<Position> expectedEntriesToConsumer1 = new CopyOnWriteArrayList<>(); + final List<Position> expectedEntriesToConsumer2 = new CopyOnWriteArrayList<>(); - final AtomicInteger remainingEntriesNum = new AtomicInteger( - expectedEntriesToConsumer1.size() + expectedEntriesToConsumer2.size()); - - // Messages with key1 are routed to consumer1 and messages with key2 are routed to consumer2 - final List<Entry> allEntries = new ArrayList<>(); - allEntries.add(EntryImpl.create(1, 1, createMessage("message1", 1, "key2"))); - allEntries.add(EntryImpl.create(1, 2, createMessage("message2", 2, "key1"))); - allEntries.add(EntryImpl.create(1, 3, createMessage("message3", 3, "key1"))); - allEntries.forEach(entry -> ((EntryImpl) entry).retain()); + final CountDownLatch remainingEntriesNum = new CountDownLatch(3); - final List<Entry> redeliverEntries = new ArrayList<>(); - redeliverEntries.add(allEntries.get(0)); // message1 - final List<Entry> readEntries = new ArrayList<>(); - readEntries.add(allEntries.get(2)); // message3 - - final Consumer consumer1 = mock(Consumer.class); + final Consumer consumer1 = createMockConsumer(); doReturn("consumer1").when(consumer1).consumerName(); // Change availablePermits of consumer1 to 0 and then back to normal when(consumer1.getAvailablePermits()).thenReturn(0).thenReturn(10); doReturn(true).when(consumer1).isWritable(); - doAnswer(invocationOnMock -> { - @SuppressWarnings("unchecked") - List<Entry> entries = (List<Entry>) invocationOnMock.getArgument(0); + mockSendMessages(consumer1, entries -> { for (Entry entry : entries) { - remainingEntriesNum.decrementAndGet(); actualEntriesToConsumer1.add(entry.getPosition()); + remainingEntriesNum.countDown(); } - return channelMock; - }).when(consumer1).sendMessages(anyList(), any(EntryBatchSizes.class), any(EntryBatchIndexesAcks.class), - anyInt(), anyLong(), anyLong(), any(RedeliveryTracker.class)); + }); - final Consumer consumer2 = mock(Consumer.class); + final Consumer consumer2 = createMockConsumer(); doReturn("consumer2").when(consumer2).consumerName(); when(consumer2.getAvailablePermits()).thenReturn(10); doReturn(true).when(consumer2).isWritable(); - doAnswer(invocationOnMock -> { - @SuppressWarnings("unchecked") - List<Entry> entries = (List<Entry>) invocationOnMock.getArgument(0); + mockSendMessages(consumer2, entries -> { for (Entry entry : entries) { - remainingEntriesNum.decrementAndGet(); actualEntriesToConsumer2.add(entry.getPosition()); + remainingEntriesNum.countDown(); } - return channelMock; - }).when(consumer2).sendMessages(anyList(), any(EntryBatchSizes.class), any(EntryBatchIndexesAcks.class), - anyInt(), anyLong(), anyLong(), any(RedeliveryTracker.class)); + }); - persistentDispatcher.addConsumer(consumer1); - persistentDispatcher.addConsumer(consumer2); + persistentDispatcher.addConsumer(consumer1).join(); + persistentDispatcher.addConsumer(consumer2).join(); final Field totalAvailablePermitsField = PersistentDispatcherMultipleConsumers.class .getDeclaredField("totalAvailablePermits"); totalAvailablePermitsField.setAccessible(true); totalAvailablePermitsField.set(persistentDispatcher, 1000); - final Field redeliveryMessagesField = PersistentDispatcherMultipleConsumers.class - .getDeclaredField("redeliveryMessages"); - redeliveryMessagesField.setAccessible(true); - MessageRedeliveryController redeliveryMessages = (MessageRedeliveryController) redeliveryMessagesField - .get(persistentDispatcher); - redeliveryMessages.add(allEntries.get(0).getLedgerId(), allEntries.get(0).getEntryId(), - getStickyKeyHash(allEntries.get(0))); // message1 - redeliveryMessages.add(allEntries.get(1).getLedgerId(), allEntries.get(1).getEntryId(), - getStickyKeyHash(allEntries.get(1))); // message2 + StickyKeyConsumerSelector selector = persistentDispatcher.getSelector(); + + String keyForConsumer1 = generateKeyForConsumer(selector, consumer1); + String keyForConsumer2 = generateKeyForConsumer(selector, consumer2); + + // Messages with key1 are routed to consumer1 and messages with key2 are routed to consumer2 + final List<Entry> allEntries = new ArrayList<>(); + allEntries.add(createEntry(1, 1, "message1", 1, keyForConsumer1)); + allEntries.add(createEntry(1, 2, "message2", 2, keyForConsumer1)); + allEntries.add(createEntry(1, 3, "message3", 3, keyForConsumer2)); + + // add first entry to redeliver initially + final List<Entry> redeliverEntries = new ArrayList<>(); + redeliverEntries.add(allEntries.get(0)); // message1 + + expectedEntriesToConsumer1.add(allEntries.get(0).getPosition()); + expectedEntriesToConsumer1.add(allEntries.get(1).getPosition()); + expectedEntriesToConsumer2.add(allEntries.get(2).getPosition()); // Mock Cursor#asyncReplayEntries doAnswer(invocationOnMock -> { - @SuppressWarnings("unchecked") - Set<Position> positions = (Set<Position>) invocationOnMock.getArgument(0); - List<Entry> entries = allEntries.stream().filter(entry -> positions.contains(entry.getPosition())) + Set<Position> positionsArg = invocationOnMock.getArgument(0); + Set<Position> positions = new TreeSet<>(positionsArg); + Set<Position> alreadyReceived = new TreeSet<>(); + alreadyReceived.addAll(actualEntriesToConsumer1); + alreadyReceived.addAll(actualEntriesToConsumer2); + List<Entry> entries = allEntries.stream().filter(entry -> entry.getLedgerId() != -1 + && positions.contains(entry.getPosition()) + && !alreadyReceived.contains(entry.getPosition())) .collect(Collectors.toList()); - if (!entries.isEmpty()) { - ((PersistentStickyKeyDispatcherMultipleConsumers) invocationOnMock.getArgument(1)) - .readEntriesComplete(entries, PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Replay); - } - return Collections.emptySet(); - }).when(cursorMock).asyncReplayEntries(anySet(), any(PersistentStickyKeyDispatcherMultipleConsumers.class), - eq(PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Replay), anyBoolean()); + AsyncCallbacks.ReadEntriesCallback callback = invocationOnMock.getArgument(1); + Object ctx = invocationOnMock.getArgument(2); + callback.readEntriesComplete(copyEntries(entries), ctx); + return alreadyReceived; + }).when(cursorMock).asyncReplayEntries(anySet(), any(), any(), anyBoolean()); // Mock Cursor#asyncReadEntriesOrWait - AtomicBoolean asyncReadEntriesOrWaitCalled = new AtomicBoolean(); doAnswer(invocationOnMock -> { - if (asyncReadEntriesOrWaitCalled.compareAndSet(false, true)) { - ((PersistentStickyKeyDispatcherMultipleConsumers) invocationOnMock.getArgument(2)) - .readEntriesComplete(readEntries, PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal); - } else { - ((PersistentStickyKeyDispatcherMultipleConsumers) invocationOnMock.getArgument(2)) - .readEntriesComplete(Collections.emptyList(), PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal); - } + int maxEntries = invocationOnMock.getArgument(0); + Set<Position> alreadyReceived = new TreeSet<>(); + alreadyReceived.addAll(actualEntriesToConsumer1); + alreadyReceived.addAll(actualEntriesToConsumer2); + List<Entry> entries = allEntries.stream() + .filter(entry -> entry.getLedgerId() != -1 && !alreadyReceived.contains(entry.getPosition())) + .limit(maxEntries) + .collect(Collectors.toList()); + AsyncCallbacks.ReadEntriesCallback callback = invocationOnMock.getArgument(2); + Object ctx = invocationOnMock.getArgument(3); + callback.readEntriesComplete(copyEntries(entries), ctx); return null; - }).when(cursorMock).asyncReadEntriesOrWait(anyInt(), anyLong(), - any(PersistentStickyKeyDispatcherMultipleConsumers.class), - eq(PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Normal), any()); - - // (1) Run sendMessagesToConsumers - // (2) Attempts to send message1 to consumer1 but skipped because availablePermits is 0 - // (3) Change availablePermits of consumer1 to 10 - // (4) Run readMoreEntries internally - // (5) Run sendMessagesToConsumers internally - // (6) Attempts to send message3 to consumer2 but skipped because redeliveryMessages contains message2 - persistentDispatcher.sendMessagesToConsumers(PersistentStickyKeyDispatcherMultipleConsumers.ReadType.Replay, - redeliverEntries, true); - while (remainingEntriesNum.get() > 0) { - // (7) Run readMoreEntries and resend message1 to consumer1 and message2-3 to consumer2 - persistentDispatcher.readMoreEntries(); - } + }).when(cursorMock).asyncReadEntriesOrWait(anyInt(), anyLong(), any(), any(), any()); + + // add entries to redeliver + redeliverEntries.forEach(entry -> { + // calculate hash + EntryAndMetadata entryAndMetadata = EntryAndMetadata.create(entry); + int stickyKeyHash = StickyKeyConsumerSelector.makeStickyKeyHash(entryAndMetadata.getStickyKey()); + // add to redeliver + persistentDispatcher.addMessageToReplay(entry.getLedgerId(), entry.getEntryId(), stickyKeyHash); + }); + + // trigger logic to read entries, includes redelivery logic + persistentDispatcher.readMoreEntriesAsync(); + + assertTrue(remainingEntriesNum.await(5, TimeUnit.SECONDS)); assertThat(actualEntriesToConsumer1).containsExactlyElementsOf(expectedEntriesToConsumer1); assertThat(actualEntriesToConsumer2).containsExactlyElementsOf(expectedEntriesToConsumer2); @@ -535,22 +605,39 @@ public class PersistentStickyKeyDispatcherMultipleConsumersTest { allEntries.forEach(entry -> entry.release()); } - private ByteBuf createMessage(String message, int sequenceId) { - return createMessage(message, sequenceId, "testKey"); + private String generateKeyForConsumer(StickyKeyConsumerSelector selector, Consumer consumer) { + int i = 0; + while (!Thread.currentThread().isInterrupted()) { + String key = "key" + i++; + Consumer selectedConsumer = selector.select(key.getBytes(UTF_8)); + if (selectedConsumer == consumer) { + return key; + } + } + return null; } - private ByteBuf createMessage(String message, int sequenceId, String key) { + private EntryImpl createEntry(long ledgerId, long entryId, String message, long sequenceId) { + return createEntry(ledgerId, entryId, message, sequenceId, "testKey"); + } + + private EntryImpl createEntry(long ledgerId, long entryId, String message, long sequenceId, String key) { + ByteBuf data = createMessage(message, sequenceId, key); + EntryImpl entry = EntryImpl.create(ledgerId, entryId, data); + data.release(); + return entry; + } + + private ByteBuf createMessage(String message, long sequenceId, String key) { MessageMetadata messageMetadata = new MessageMetadata() .setSequenceId(sequenceId) .setProducerName("testProducer") .setPartitionKey(key) .setPartitionKeyB64Encoded(false) .setPublishTime(System.currentTimeMillis()); - return serializeMetadataAndPayload(Commands.ChecksumType.Crc32c, messageMetadata, Unpooled.copiedBuffer(message.getBytes(UTF_8))); - } - - private int getStickyKeyHash(Entry entry) { - byte[] stickyKey = Commands.peekStickyKey(entry.getDataBuffer(), topicName, subscriptionName); - return StickyKeyConsumerSelector.makeStickyKeyHash(stickyKey); + ByteBuf payload = Unpooled.copiedBuffer(message.getBytes(UTF_8)); + ByteBuf byteBuf = serializeMetadataAndPayload(Commands.ChecksumType.Crc32c, messageMetadata, payload); + payload.release(); + return byteBuf; } } diff --git a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/AcknowledgementsGroupingTrackerTest.java b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/AcknowledgementsGroupingTrackerTest.java index 1d1a6f85bfd..efa398afba6 100644 --- a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/AcknowledgementsGroupingTrackerTest.java +++ b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/AcknowledgementsGroupingTrackerTest.java @@ -18,6 +18,7 @@ */ package org.apache.pulsar.client.impl; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -37,15 +38,16 @@ import java.util.Collections; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.pulsar.client.api.MessageId; import org.apache.pulsar.client.api.MessageIdAdv; import org.apache.pulsar.client.impl.conf.ClientConfigurationData; import org.apache.pulsar.client.impl.conf.ConsumerConfigurationData; import org.apache.pulsar.client.util.TimedCompletableFuture; import org.apache.pulsar.common.api.proto.CommandAck.AckType; -import org.apache.pulsar.common.util.collections.ConcurrentBitSetRecyclable; import org.apache.pulsar.common.util.collections.ConcurrentOpenHashMap; import org.apache.pulsar.common.api.proto.ProtocolVersion; +import org.apache.pulsar.common.util.collections.ConcurrentBitSetRecyclable; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.DataProvider; @@ -56,6 +58,7 @@ public class AcknowledgementsGroupingTrackerTest { private ClientCnx cnx; private ConsumerImpl<?> consumer; private EventLoopGroup eventLoopGroup; + private AtomicBoolean returnCnx = new AtomicBoolean(true); @BeforeClass public void setup() throws NoSuchFieldException, IllegalAccessException { @@ -68,12 +71,12 @@ public class AcknowledgementsGroupingTrackerTest { ConnectionPool connectionPool = mock(ConnectionPool.class); when(client.getCnxPool()).thenReturn(connectionPool); doReturn(client).when(consumer).getClient(); - doReturn(cnx).when(consumer).getClientCnx(); doReturn(new ConsumerStatsRecorderImpl()).when(consumer).getStats(); doReturn(new UnAckedMessageTracker().UNACKED_MESSAGE_TRACKER_DISABLED) .when(consumer).getUnAckedMessageTracker(); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - when(cnx.ctx()).thenReturn(ctx); + ChannelHandlerContext ctx = ClientTestFixtures.mockChannelHandlerContext(); + doAnswer(invocation -> returnCnx.get() ? cnx : null).when(consumer).getClientCnx(); + doReturn(ctx).when(cnx).ctx(); } @DataProvider(name = "isNeedReceipt") @@ -131,8 +134,6 @@ public class AcknowledgementsGroupingTrackerTest { tracker.addAcknowledgment(msg6, AckType.Individual, Collections.emptyMap()); assertTrue(tracker.isDuplicate(msg6)); - when(consumer.getClientCnx()).thenReturn(cnx); - tracker.flush(); assertTrue(tracker.isDuplicate(msg1)); @@ -191,8 +192,6 @@ public class AcknowledgementsGroupingTrackerTest { tracker.addListAcknowledgment(Collections.singletonList(msg6), AckType.Individual, Collections.emptyMap()); assertTrue(tracker.isDuplicate(msg6)); - when(consumer.getClientCnx()).thenReturn(cnx); - tracker.flush(); assertTrue(tracker.isDuplicate(msg1)); @@ -219,12 +218,13 @@ public class AcknowledgementsGroupingTrackerTest { assertFalse(tracker.isDuplicate(msg1)); - when(consumer.getClientCnx()).thenReturn(null); - - tracker.addAcknowledgment(msg1, AckType.Individual, Collections.emptyMap()); - assertFalse(tracker.isDuplicate(msg1)); - - when(consumer.getClientCnx()).thenReturn(cnx); + returnCnx.set(false); + try { + tracker.addAcknowledgment(msg1, AckType.Individual, Collections.emptyMap()); + assertFalse(tracker.isDuplicate(msg1)); + } finally { + returnCnx.set(true); + } tracker.flush(); assertFalse(tracker.isDuplicate(msg1)); @@ -248,12 +248,13 @@ public class AcknowledgementsGroupingTrackerTest { assertFalse(tracker.isDuplicate(msg1)); - when(consumer.getClientCnx()).thenReturn(null); - - tracker.addListAcknowledgment(Collections.singletonList(msg1), AckType.Individual, Collections.emptyMap()); - assertTrue(tracker.isDuplicate(msg1)); - - when(consumer.getClientCnx()).thenReturn(cnx); + returnCnx.set(false); + try { + tracker.addListAcknowledgment(Collections.singletonList(msg1), AckType.Individual, Collections.emptyMap()); + assertTrue(tracker.isDuplicate(msg1)); + } finally { + returnCnx.set(true); + } tracker.flush(); assertFalse(tracker.isDuplicate(msg1)); @@ -313,8 +314,6 @@ public class AcknowledgementsGroupingTrackerTest { tracker.addAcknowledgment(msg6, AckType.Individual, Collections.emptyMap()); assertTrue(tracker.isDuplicate(msg6)); - when(consumer.getClientCnx()).thenReturn(cnx); - tracker.flush(); assertTrue(tracker.isDuplicate(msg1)); @@ -375,8 +374,6 @@ public class AcknowledgementsGroupingTrackerTest { tracker.addListAcknowledgment(Collections.singletonList(msg6), AckType.Individual, Collections.emptyMap()); assertTrue(tracker.isDuplicate(msg6)); - when(consumer.getClientCnx()).thenReturn(cnx); - tracker.flush(); assertTrue(tracker.isDuplicate(msg1)); diff --git a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/BinaryProtoLookupServiceTest.java b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/BinaryProtoLookupServiceTest.java index 984b201d1ce..d8aa5e5cd08 100644 --- a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/BinaryProtoLookupServiceTest.java +++ b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/BinaryProtoLookupServiceTest.java @@ -20,6 +20,7 @@ package org.apache.pulsar.client.impl; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -39,6 +40,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.apache.commons.lang3.tuple.Pair; import org.apache.pulsar.client.api.PulsarClientException.LookupException; @@ -73,8 +75,17 @@ public class BinaryProtoLookupServiceTest { CompletableFuture<LookupDataResult> lookupFuture2 = CompletableFuture.completedFuture(lookupResult2); ClientCnx clientCnx = mock(ClientCnx.class); - when(clientCnx.newLookup(any(ByteBuf.class), anyLong())).thenReturn(lookupFuture1, lookupFuture1, - lookupFuture2); + AtomicInteger lookupInvocationCounter = new AtomicInteger(); + doAnswer(invocation -> { + ByteBuf byteBuf = invocation.getArgument(0); + byteBuf.release(); + int lookupInvocationCount = lookupInvocationCounter.incrementAndGet(); + if (lookupInvocationCount < 3) { + return lookupFuture1; + } else { + return lookupFuture2; + } + }).when(clientCnx).newLookup(any(ByteBuf.class), anyLong()); CompletableFuture<ClientCnx> connectionFuture = CompletableFuture.completedFuture(clientCnx); diff --git a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientCnxTest.java b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientCnxTest.java index 22220805814..d5fbfd22321 100644 --- a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientCnxTest.java +++ b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientCnxTest.java @@ -21,14 +21,11 @@ package org.apache.pulsar.client.impl; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; import io.netty.buffer.ByteBuf; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.EventLoopGroup; import io.netty.util.concurrent.DefaultThreadFactory; @@ -63,14 +60,7 @@ public class ClientCnxTest { conf.setOperationTimeoutMs(10); conf.setKeepAliveIntervalSeconds(0); ClientCnx cnx = new ClientCnx(conf, eventLoop); - - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - Channel channel = mock(Channel.class); - when(ctx.channel()).thenReturn(channel); - ChannelFuture listenerFuture = mock(ChannelFuture.class); - when(listenerFuture.addListener(any())).thenReturn(listenerFuture); - when(ctx.writeAndFlush(any())).thenReturn(listenerFuture); - + ChannelHandlerContext ctx = ClientTestFixtures.mockChannelHandlerContext(); cnx.channelActive(ctx); try { @@ -89,13 +79,7 @@ public class ClientCnxTest { conf.setOperationTimeoutMs(10_000); conf.setKeepAliveIntervalSeconds(0); ClientCnx cnx = new ClientCnx(conf, eventLoop); - - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - Channel channel = mock(Channel.class); - when(ctx.channel()).thenReturn(channel); - ChannelFuture listenerFuture = mock(ChannelFuture.class); - when(listenerFuture.addListener(any())).thenReturn(listenerFuture); - when(ctx.writeAndFlush(any())).thenReturn(listenerFuture); + ChannelHandlerContext ctx = ClientTestFixtures.mockChannelHandlerContext(); cnx.channelActive(ctx); CountDownLatch countDownLatch = new CountDownLatch(1); CompletableFuture<Exception> completableFuture = new CompletableFuture<>(); @@ -127,13 +111,7 @@ public class ClientCnxTest { conf.setOperationTimeoutMs(10_000); conf.setKeepAliveIntervalSeconds(0); ClientCnx cnx = new ClientCnx(conf, eventLoop); - - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - Channel channel = mock(Channel.class); - when(ctx.channel()).thenReturn(channel); - ChannelFuture listenerFuture = mock(ChannelFuture.class); - when(listenerFuture.addListener(any())).thenReturn(listenerFuture); - when(ctx.writeAndFlush(any())).thenReturn(listenerFuture); + ChannelHandlerContext ctx = ClientTestFixtures.mockChannelHandlerContext(); cnx.channelActive(ctx); cnx.state = ClientCnx.State.Ready; CountDownLatch countDownLatch = new CountDownLatch(1); @@ -170,13 +148,7 @@ public class ClientCnxTest { conf.setOperationTimeoutMs(10_000); conf.setKeepAliveIntervalSeconds(0); ClientCnx cnx = new ClientCnx(conf, eventLoop); - - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - Channel channel = mock(Channel.class); - when(ctx.channel()).thenReturn(channel); - ChannelFuture listenerFuture = mock(ChannelFuture.class); - when(listenerFuture.addListener(any())).thenReturn(listenerFuture); - when(ctx.writeAndFlush(any())).thenReturn(listenerFuture); + ChannelHandlerContext ctx = ClientTestFixtures.mockChannelHandlerContext(); cnx.channelActive(ctx); for (int i = 0; i < 5001; i++) { cnx.newLookup(null, i); @@ -197,9 +169,7 @@ public class ClientCnxTest { conf.setOperationTimeoutMs(10); ClientCnx cnx = new ClientCnx(conf, eventLoop); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - Channel channel = mock(Channel.class); - when(ctx.channel()).thenReturn(channel); + ChannelHandlerContext ctx = ClientTestFixtures.mockChannelHandlerContext(); Field ctxField = PulsarHandler.class.getDeclaredField("ctx"); ctxField.setAccessible(true); @@ -231,9 +201,7 @@ public class ClientCnxTest { ClientConfigurationData conf = new ClientConfigurationData(); ClientCnx cnx = new ClientCnx(conf, eventLoop); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - Channel channel = mock(Channel.class); - when(ctx.channel()).thenReturn(channel); + ChannelHandlerContext ctx = ClientTestFixtures.mockChannelHandlerContext(); Field ctxField = PulsarHandler.class.getDeclaredField("ctx"); ctxField.setAccessible(true); @@ -246,10 +214,6 @@ public class ClientCnxTest { cnxField.setAccessible(true); cnxField.set(cnx, ClientCnx.State.SentConnectFrame); - ChannelFuture listenerFuture = mock(ChannelFuture.class); - when(listenerFuture.addListener(any())).thenReturn(listenerFuture); - when(ctx.writeAndFlush(any())).thenReturn(listenerFuture); - ByteBuf getLastIdCmd = Commands.newGetLastMessageId(5, requestId); CompletableFuture<?> future = cnx.sendGetLastMessageId(getLastIdCmd, requestId); @@ -382,13 +346,7 @@ public class ClientCnxTest { ClientConfigurationData conf = new ClientConfigurationData(); ClientCnx cnx = new ClientCnx(conf, eventLoop); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - Channel channel = mock(Channel.class); - when(ctx.channel()).thenReturn(channel); - - ChannelFuture listenerFuture = mock(ChannelFuture.class); - when(listenerFuture.addListener(any())).thenReturn(listenerFuture); - when(ctx.writeAndFlush(any())).thenReturn(listenerFuture); + ChannelHandlerContext ctx = ClientTestFixtures.mockChannelHandlerContext(); Field ctxField = PulsarHandler.class.getDeclaredField("ctx"); ctxField.setAccessible(true); diff --git a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientTestFixtures.java b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientTestFixtures.java index ff7d7f12dd4..bac69a7fbd7 100644 --- a/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientTestFixtures.java +++ b/pulsar-client/src/test/java/org/apache/pulsar/client/impl/ClientTestFixtures.java @@ -22,11 +22,21 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.EventLoop; +import io.netty.util.ReferenceCountUtil; import io.netty.util.Timer; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.SucceededFuture; import java.net.SocketAddress; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; @@ -68,11 +78,7 @@ class ClientTestFixtures { } static PulsarClientImpl mockClientCnx(PulsarClientImpl clientMock) { - ClientCnx clientCnxMock = mock(ClientCnx.class, Mockito.RETURNS_DEEP_STUBS); - when(clientCnxMock.ctx()).thenReturn(mock(ChannelHandlerContext.class)); - when(clientCnxMock.sendRequestWithId(any(), anyLong())) - .thenReturn(CompletableFuture.completedFuture(mock(ProducerResponse.class))); - when(clientCnxMock.channel().remoteAddress()).thenReturn(mock(SocketAddress.class)); + ClientCnx clientCnxMock = mockClientCnx(); when(clientMock.getConnection(any())).thenReturn(CompletableFuture.completedFuture(clientCnxMock)); when(clientMock.getConnection(anyString())).thenReturn(CompletableFuture.completedFuture(clientCnxMock)); when(clientMock.getConnection(anyString(), anyInt())) @@ -87,6 +93,82 @@ class ClientTestFixtures { return clientMock; } + public static ClientCnx mockClientCnx() { + ClientCnx clientCnxMock = mock(ClientCnx.class, Mockito.RETURNS_DEEP_STUBS); + ChannelHandlerContext ctx = mockChannelHandlerContext(); + doReturn(ctx).when(clientCnxMock).ctx(); + doAnswer(invocation -> { + ByteBuf buf = invocation.getArgument(0); + buf.release(); + return CompletableFuture.completedFuture(mock(ProducerResponse.class)); + }).when(clientCnxMock).sendRequestWithId(any(), anyLong()); + when(clientCnxMock.channel().remoteAddress()).thenReturn(mock(SocketAddress.class)); + return clientCnxMock; + } + + /** + * Mock a ChannelHandlerContext where write and writeAndFlush are always successful. + * This might not be suitable for all tests. + * + * @return a mocked ChannelHandlerContext + */ + public static ChannelHandlerContext mockChannelHandlerContext() { + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + + // return an empty channel mock from ctx.channel() + Channel channel = mock(Channel.class); + when(ctx.channel()).thenReturn(channel); + + // handle write and writeAndFlush methods so that the input message is released + + // create a listener future that is returned from write and writeAndFlush + // that immediately completes the listener in the calling thread as if it was successful + ChannelFuture listenerFuture = mock(ChannelFuture.class); + Future<Void> succeededFuture = createSucceededFuture(); + doAnswer(invocation -> { + GenericFutureListener<Future<Void>> listener = invocation.getArgument(0); + listener.operationComplete(succeededFuture); + return listenerFuture; + }).when(listenerFuture).addListener(any()); + + // handle write and writeAndFlush methods so that the input message is released + doAnswer(invocation -> { + Object msg = invocation.getArgument(0); + ReferenceCountUtil.release(msg); + return listenerFuture; + }).when(ctx).write(any(), any()); + doAnswer(invocation -> { + Object msg = invocation.getArgument(0); + ReferenceCountUtil.release(msg); + return listenerFuture; + }).when(ctx).writeAndFlush(any(), any()); + doAnswer(invocation -> { + Object msg = invocation.getArgument(0); + ReferenceCountUtil.release(msg); + return listenerFuture; + }).when(ctx).writeAndFlush(any()); + + return ctx; + } + + public static Future<Void> createSucceededFuture() { + EventExecutor eventExecutor = mockEventExecutor(); + // create a succeeded future that is returned from the listener, listeners will run in the calling thread + // using the mocked EventExecutor + SucceededFuture<Void> succeededFuture = new SucceededFuture<>(eventExecutor, null); + return succeededFuture; + } + + public static EventExecutor mockEventExecutor() { + // mock an EventExecutor that runs the listener in the calling thread + EventExecutor eventExecutor = mock(EventExecutor.class); + doAnswer(invocation -> { + invocation.getArgument(0, Runnable.class).run(); + return null; + }).when(eventExecutor).execute(any(Runnable.class)); + return eventExecutor; + } + static <T> CompletableFuture<T> createDelayedCompletedFuture(T result, int delayMillis) { CompletableFuture<T> future = new CompletableFuture<>(); SCHEDULER.schedule(() -> future.complete(result), delayMillis, TimeUnit.MILLISECONDS);
