brandboat commented on code in PR #18105:
URL: https://github.com/apache/kafka/pull/18105#discussion_r1881940289


##########
core/src/test/java/kafka/test/api/ShareConsumerTest.java:
##########
@@ -1069,100 +1055,50 @@ public void 
testMultipleConsumersInMultipleGroupsConcurrentConsumption(String pe
         alterShareAutoOffsetReset(groupId2, "earliest");
         alterShareAutoOffsetReset(groupId3, "earliest");
 
-        ExecutorService producerExecutorService = 
Executors.newFixedThreadPool(producerCount);
-        ExecutorService shareGroupExecutorService1 = 
Executors.newFixedThreadPool(consumerCount);
-        ExecutorService shareGroupExecutorService2 = 
Executors.newFixedThreadPool(consumerCount);
-        ExecutorService shareGroupExecutorService3 = 
Executors.newFixedThreadPool(consumerCount);
-
-        CountDownLatch startSignal = new CountDownLatch(producerCount);
-
-        ConcurrentLinkedQueue<CompletableFuture<Integer>> producerFutures = 
new ConcurrentLinkedQueue<>();
-
+        List<CompletableFuture<Integer>> producerFutures = new ArrayList<>();
         for (int i = 0; i < producerCount; i++) {
-            producerExecutorService.submit(() -> {
-                CompletableFuture<Integer> future = 
produceMessages(messagesPerProducer);
-                producerFutures.add(future);
-                startSignal.countDown();
-            });
+            producerFutures.add(CompletableFuture.supplyAsync(() -> 
produceMessages(messagesPerProducer)));
         }
-
-        ConcurrentLinkedQueue<CompletableFuture<Integer>> futures1 = new 
ConcurrentLinkedQueue<>();
-        ConcurrentLinkedQueue<CompletableFuture<Integer>> futures2 = new 
ConcurrentLinkedQueue<>();
-        ConcurrentLinkedQueue<CompletableFuture<Integer>> futures3 = new 
ConcurrentLinkedQueue<>();
-
         // Wait for the producers to run
-        try {
-            boolean signalled = startSignal.await(15, TimeUnit.SECONDS);
-            assertTrue(signalled);
-        } catch (InterruptedException e) {
-            fail("Exception awaiting start signal");
-        }
+        assertDoesNotThrow(() -> 
CompletableFuture.allOf(producerFutures.toArray(CompletableFuture[]::new))
+                .get(15, TimeUnit.SECONDS), "Exception awaiting 
produceMessages");
+        int actualMessageSent = 
producerFutures.stream().map(CompletableFuture::join).reduce(Integer::sum).orElse(0);
 
-        int maxBytes = 100000;
+        List<CompletableFuture<Integer>> consumeMessagesFutures1 = new 
ArrayList<>();
+        List<CompletableFuture<Integer>> consumeMessagesFutures2 = new 
ArrayList<>();
+        List<CompletableFuture<Integer>> consumeMessagesFutures3 = new 
ArrayList<>();
 
+        int maxBytes = 100000;
         for (int i = 0; i < consumerCount; i++) {
             final int consumerNumber = i + 1;
-            shareGroupExecutorService1.submit(() -> {
-                CompletableFuture<Integer> future = new CompletableFuture<>();
-                futures1.add(future);
-                consumeMessages(totalMessagesConsumedGroup1, 
totalMessagesSent, "group1", consumerNumber, 100, true, future, maxBytes);
-            });
-            shareGroupExecutorService2.submit(() -> {
-                CompletableFuture<Integer> future = new CompletableFuture<>();
-                futures2.add(future);
-                consumeMessages(totalMessagesConsumedGroup2, 
totalMessagesSent, "group2", consumerNumber, 100, true, future, maxBytes);
-            });
-            shareGroupExecutorService3.submit(() -> {
-                CompletableFuture<Integer> future = new CompletableFuture<>();
-                futures3.add(future);
-                consumeMessages(totalMessagesConsumedGroup3, 
totalMessagesSent, "group3", consumerNumber, 100, true, future, maxBytes);
-            });
-        }
-        producerExecutorService.shutdown();
-        shareGroupExecutorService1.shutdown();
-        shareGroupExecutorService2.shutdown();
-        shareGroupExecutorService3.shutdown();
-        try {
-            shareGroupExecutorService1.awaitTermination(120, 
TimeUnit.SECONDS); // Wait for all consumer threads for group 1 to complete
-            shareGroupExecutorService2.awaitTermination(120, 
TimeUnit.SECONDS); // Wait for all consumer threads for group 2 to complete
-            shareGroupExecutorService3.awaitTermination(120, 
TimeUnit.SECONDS); // Wait for all consumer threads for group 3 to complete
-
-            int totalResult1 = 0;
-            for (CompletableFuture<Integer> future : futures1) {
-                totalResult1 += future.get();
-            }
+            consumeMessagesFutures1.add(CompletableFuture.supplyAsync(() ->
+                    consumeMessages(totalMessagesConsumedGroup1, 
totalMessagesSent,
+                            "group1", consumerNumber, 100, true, maxBytes)));
 
-            int totalResult2 = 0;
-            for (CompletableFuture<Integer> future : futures2) {
-                totalResult2 += future.get();
-            }
+            consumeMessagesFutures2.add(CompletableFuture.supplyAsync(() ->
+                    consumeMessages(totalMessagesConsumedGroup2, 
totalMessagesSent,
+                            "group2", consumerNumber, 100, true, maxBytes)));
 
-            int totalResult3 = 0;
-            for (CompletableFuture<Integer> future : futures3) {
-                totalResult3 += future.get();
-            }
+            consumeMessagesFutures3.add(CompletableFuture.supplyAsync(() ->
+                    consumeMessages(totalMessagesConsumedGroup3, 
totalMessagesSent,
+                            "group3", consumerNumber, 100, true, maxBytes)));
+        }
 
-            assertEquals(totalMessagesSent, totalMessagesConsumedGroup1.get());
-            assertEquals(totalMessagesSent, totalMessagesConsumedGroup2.get());
-            assertEquals(totalMessagesSent, totalMessagesConsumedGroup3.get());
-            assertEquals(totalMessagesSent, totalResult1);
-            assertEquals(totalMessagesSent, totalResult2);
-            assertEquals(totalMessagesSent, totalResult3);
+        CompletableFuture.allOf(Stream.of(consumeMessagesFutures1.stream(), 
consumeMessagesFutures2.stream(),
+                        
consumeMessagesFutures3.stream()).flatMap(Function.identity()).toArray(CompletableFuture[]::new))
+                .get(120, TimeUnit.SECONDS);
 
-            int actualMessagesSent = 0;
-            try {
-                producerExecutorService.awaitTermination(60, 
TimeUnit.SECONDS); // Wait for all producer threads to complete
+        int totalResult1 = 
consumeMessagesFutures1.stream().map(CompletableFuture::join).reduce(Integer::sum).orElse(0);
+        int totalResult2 = 
consumeMessagesFutures2.stream().map(CompletableFuture::join).reduce(Integer::sum).orElse(0);
+        int totalResult3 = 
consumeMessagesFutures3.stream().map(CompletableFuture::join).reduce(Integer::sum).orElse(0);
 
-                for (CompletableFuture<Integer> future : producerFutures) {
-                    actualMessagesSent += future.get();
-                }
-            } catch (Exception e) {
-                fail("Exception occurred : " + e.getMessage());
-            }
-            assertEquals(totalMessagesSent, actualMessagesSent);
-        } catch (Exception e) {
-            fail("Exception occurred : " + e.getMessage());
-        }
+        assertEquals(totalMessagesSent, totalMessagesConsumedGroup1.get());

Review Comment:
   Perhaps we can remove all the AtomicIntegers here. As you said, we can now rely
on the returned values to verify how many records are consumed. WDYT?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to