JimmyWang6 commented on code in PR #20246:
URL: https://github.com/apache/kafka/pull/20246#discussion_r2461555694


##########
core/src/test/java/kafka/server/share/SharePartitionTest.java:
##########
@@ -8801,6 +8838,294 @@ public void 
testRecordArchivedWithWriteStateRPCFailure() throws InterruptedExcep
         assertEquals(2, 
sharePartition.cachedState().get(7L).batchDeliveryCount());
     }
 
+    @Test
+    public void testAcquireSingleBatchInRecordLimitMode() {
+        Persister persister = Mockito.mock(Persister.class);
+        SharePartition sharePartition = 
Mockito.spy(SharePartitionBuilder.builder()
+            .withState(SharePartitionState.ACTIVE)
+            .withDefaultAcquisitionLockTimeoutMs(ACQUISITION_LOCK_TIMEOUT_MS)
+            .withPersister(persister)
+            .build());
+        
Mockito.doReturn(true).when(sharePartition).isRecordLimitMode(Mockito.any());
+        WriteShareGroupStateResult writeShareGroupStateResult = 
Mockito.mock(WriteShareGroupStateResult.class);
+        
Mockito.when(writeShareGroupStateResult.topicsData()).thenReturn(List.of(
+            new TopicData<>(TOPIC_ID_PARTITION.topicId(), List.of(
+                PartitionFactory.newPartitionErrorData(0, Errors.NONE.code(), 
Errors.NONE.message())))));
+        
Mockito.when(persister.writeState(Mockito.any())).thenReturn(CompletableFuture.completedFuture(writeShareGroupStateResult));
+
+        // Member-1 attempts to acquire records in strict mode with a maximum 
fetch limit of 5 records.
+        MemoryRecords records = memoryRecords(10);
+        List<AcquiredRecords> acquiredRecordsList = 
fetchAcquiredRecords(sharePartition.acquire(
+            MEMBER_ID,
+            SHARE_ACQUIRE_MODE,
+            2,
+            5,
+            DEFAULT_FETCH_OFFSET,
+            fetchPartitionData(records),
+            FETCH_ISOLATION_HWM),
+            5);
+
+        assertArrayEquals(expectedAcquiredRecord(0, 4, 1).toArray(), 
acquiredRecordsList.toArray());
+        assertEquals(5, sharePartition.nextFetchOffset());
+        assertEquals(1, sharePartition.cachedState().size());
+        assertEquals(0, sharePartition.cachedState().get(0L).firstOffset());
+        assertEquals(9, sharePartition.cachedState().get(0L).lastOffset());
+        assertNotNull(sharePartition.cachedState().get(0L).offsetState());
+        assertThrows(IllegalStateException.class, () -> 
sharePartition.cachedState().get(0L).batchState());
+        assertThrows(IllegalStateException.class, () -> 
sharePartition.cachedState().get(0L).batchMemberId());
+        assertThrows(IllegalStateException.class, () -> 
sharePartition.cachedState().get(0L).batchDeliveryCount());
+
+        assertEquals(10, 
sharePartition.cachedState().get(0L).offsetState().size());
+        assertEquals(RecordState.ACQUIRED, 
sharePartition.cachedState().get(0L).offsetState().get(0L).state());
+        assertEquals(MEMBER_ID, 
sharePartition.cachedState().get(0L).offsetState().get(0L).memberId());
+        assertEquals(RecordState.ACQUIRED, 
sharePartition.cachedState().get(0L).offsetState().get(4L).state());
+        assertEquals(MEMBER_ID, 
sharePartition.cachedState().get(0L).offsetState().get(0L).memberId());
+        assertEquals(RecordState.AVAILABLE, 
sharePartition.cachedState().get(0L).offsetState().get(5L).state());
+        assertEquals(RecordState.AVAILABLE, 
sharePartition.cachedState().get(0L).offsetState().get(9L).state());
+
+
+        // Acquire the same batch with member-2. 5 records will be acquired.
+        acquiredRecordsList = fetchAcquiredRecords(sharePartition.acquire(
+            "member-2",
+            SHARE_ACQUIRE_MODE,
+            2,
+            5,
+            DEFAULT_FETCH_OFFSET,
+            fetchPartitionData(records),
+            FETCH_ISOLATION_HWM),
+            5);
+
+        List<AcquiredRecords> expectedAcquiredRecords = new 
ArrayList<>(expectedAcquiredRecord(5, 5, 1));
+        expectedAcquiredRecords.addAll(expectedAcquiredRecord(6, 6, 1));
+        expectedAcquiredRecords.addAll(expectedAcquiredRecord(7, 7, 1));
+        expectedAcquiredRecords.addAll(expectedAcquiredRecord(8, 8, 1));
+        expectedAcquiredRecords.addAll(expectedAcquiredRecord(9, 9, 1));
+
+        assertArrayEquals(expectedAcquiredRecords.toArray(), 
acquiredRecordsList.toArray());
+        assertEquals(10, sharePartition.nextFetchOffset());
+        assertEquals(1, sharePartition.cachedState().size());
+        assertEquals(0, sharePartition.cachedState().get(0L).firstOffset());
+        assertEquals(9, sharePartition.cachedState().get(0L).lastOffset());
+
+        assertEquals(RecordState.ACQUIRED, 
sharePartition.cachedState().get(0L).offsetState().get(5L).state());
+        assertEquals("member-2", 
sharePartition.cachedState().get(0L).offsetState().get(5L).memberId());
+        assertEquals(RecordState.ACQUIRED, 
sharePartition.cachedState().get(0L).offsetState().get(9L).state());
+        assertEquals("member-2", 
sharePartition.cachedState().get(0L).offsetState().get(5L).memberId());
+    }
+
+    @Test
+    public void testAcquireMultipleBatchesInRecordLimitMode() throws 
InterruptedException {
+        Persister persister = Mockito.mock(Persister.class);
+        SharePartition sharePartition = 
Mockito.spy(SharePartitionBuilder.builder()
+            .withState(SharePartitionState.ACTIVE)
+            .withDefaultAcquisitionLockTimeoutMs(ACQUISITION_LOCK_TIMEOUT_MS)
+            .withPersister(persister)
+            .build());
+        
Mockito.doReturn(true).when(sharePartition).isRecordLimitMode(Mockito.any());
+        WriteShareGroupStateResult writeShareGroupStateResult = 
Mockito.mock(WriteShareGroupStateResult.class);
+        
Mockito.when(writeShareGroupStateResult.topicsData()).thenReturn(List.of(
+            new TopicData<>(TOPIC_ID_PARTITION.topicId(), List.of(
+                PartitionFactory.newPartitionErrorData(0, Errors.NONE.code(), 
Errors.NONE.message())))));
+        
Mockito.when(persister.writeState(Mockito.any())).thenReturn(CompletableFuture.completedFuture(writeShareGroupStateResult));
+
+        // Create 3 batches of records.
+        ByteBuffer buffer = ByteBuffer.allocate(4096);
+        memoryRecordsBuilder(buffer, 10, 5).close();
+        memoryRecordsBuilder(buffer, 15, 15).close();
+        memoryRecordsBuilder(buffer, 30, 15).close();
+
+        buffer.flip();
+
+        MemoryRecords records = MemoryRecords.readableRecords(buffer);
+        // Acquire 10 records.
+        List<AcquiredRecords> acquiredRecordsList = 
fetchAcquiredRecords(sharePartition.acquire(
+            MEMBER_ID,
+            SHARE_ACQUIRE_MODE,
+            BATCH_SIZE,
+            10,
+            DEFAULT_FETCH_OFFSET,
+            fetchPartitionData(records, 10),
+            FETCH_ISOLATION_HWM),
+            10);
+
+        assertArrayEquals(expectedAcquiredRecord(10, 19, 1).toArray(), 
acquiredRecordsList.toArray());
+        assertEquals(20, sharePartition.nextFetchOffset());
+        assertEquals(1, sharePartition.cachedState().size());
+        assertEquals(10, sharePartition.cachedState().get(10L).firstOffset());
+        assertEquals(29, sharePartition.cachedState().get(10L).lastOffset());
+        assertNotNull(sharePartition.cachedState().get(10L).offsetState());
+        assertThrows(IllegalStateException.class, () -> 
sharePartition.cachedState().get(10L).batchState());
+        assertThrows(IllegalStateException.class, () -> 
sharePartition.cachedState().get(10L).batchMemberId());
+        assertThrows(IllegalStateException.class, () -> 
sharePartition.cachedState().get(10L).batchDeliveryCount());
+        assertEquals(RecordState.ACQUIRED, 
sharePartition.cachedState().get(10L).offsetState().get(19L).state());
+        assertEquals(MEMBER_ID, 
sharePartition.cachedState().get(10L).offsetState().get(19L).memberId());
+        assertEquals(RecordState.AVAILABLE, 
sharePartition.cachedState().get(10L).offsetState().get(20L).state());
+    }
+
+    @Test
+    public void testAcknowledgeInRecordLimitMode() {
+        Persister persister = Mockito.mock(Persister.class);
+        SharePartition sharePartition = 
Mockito.spy(SharePartitionBuilder.builder()
+            .withState(SharePartitionState.ACTIVE)
+            .withDefaultAcquisitionLockTimeoutMs(ACQUISITION_LOCK_TIMEOUT_MS)
+            .withPersister(persister)
+            .build());
+        
Mockito.doReturn(true).when(sharePartition).isRecordLimitMode(Mockito.any());
+        WriteShareGroupStateResult writeShareGroupStateResult = 
Mockito.mock(WriteShareGroupStateResult.class);
+        
Mockito.when(writeShareGroupStateResult.topicsData()).thenReturn(List.of(
+            new TopicData<>(TOPIC_ID_PARTITION.topicId(), List.of(
+                PartitionFactory.newPartitionErrorData(0, Errors.NONE.code(), 
Errors.NONE.message())))));
+        
Mockito.when(persister.writeState(Mockito.any())).thenReturn(CompletableFuture.completedFuture(writeShareGroupStateResult));
+
+        MemoryRecords records = memoryRecords(10);
+        // Acquire 1 record.
+        List<AcquiredRecords> acquiredRecordsList = 
fetchAcquiredRecords(sharePartition.acquire(
+            MEMBER_ID,
+            SHARE_ACQUIRE_MODE,
+            2,
+            1,
+            DEFAULT_FETCH_OFFSET,
+            fetchPartitionData(records),
+            FETCH_ISOLATION_HWM),
+            1);
+
+        assertArrayEquals(expectedAcquiredRecord(0, 0, 1).toArray(), 
acquiredRecordsList.toArray());
+        assertEquals(1, sharePartition.nextFetchOffset());
+
+        CompletableFuture<Void> ackResult = sharePartition.acknowledge(
+            MEMBER_ID,
+            List.of(new ShareAcknowledgementBatch(0, 0, List.of((byte) 1))));
+
+        assertNull(ackResult.join());
+        assertFalse(ackResult.isCompletedExceptionally());
+        assertEquals(1, sharePartition.nextFetchOffset());
+        assertEquals(1, sharePartition.startOffset());
+        assertEquals(9, sharePartition.endOffset());
+        assertEquals(1, sharePartition.cachedState().size());
+        assertNotNull(sharePartition.cachedState().get(0L).offsetState());
+        assertEquals(RecordState.AVAILABLE, 
sharePartition.cachedState().get(0L).offsetState().get(1L).state());
+
+        // Acquire 2 records.
+        acquiredRecordsList = fetchAcquiredRecords(sharePartition.acquire(
+            MEMBER_ID,
+            SHARE_ACQUIRE_MODE,
+            2,
+            2,
+            DEFAULT_FETCH_OFFSET,
+            fetchPartitionData(records),
+            FETCH_ISOLATION_HWM),
+            2);
+
+        List<AcquiredRecords> expectedAcquiredRecords = new 
ArrayList<>(expectedAcquiredRecord(1, 1, 1));
+        expectedAcquiredRecords.addAll(expectedAcquiredRecord(2, 2, 1));
+        assertArrayEquals(expectedAcquiredRecords.toArray(), 
acquiredRecordsList.toArray());
+
+        // Ack only 1 record
+        ackResult = sharePartition.acknowledge(
+            MEMBER_ID,
+            List.of(new ShareAcknowledgementBatch(1, 1, List.of((byte) 1))));
+        assertNull(ackResult.join());
+        assertFalse(ackResult.isCompletedExceptionally());
+        assertEquals(3, sharePartition.nextFetchOffset());
+        assertEquals(2, sharePartition.startOffset());
+        assertEquals(9, sharePartition.endOffset());
+        assertEquals(1, sharePartition.cachedState().size());
+        assertNotNull(sharePartition.cachedState().get(0L).offsetState());
+        assertEquals(RecordState.ACQUIRED, 
sharePartition.cachedState().get(0L).offsetState().get(2L).state());
+        assertEquals(RecordState.AVAILABLE, 
sharePartition.cachedState().get(0L).offsetState().get(3L).state());
+    }
+
+    @Test
+    public void testAcquisitionLockTimeoutInRecordLimitMode() throws 
InterruptedException {
+        Persister persister = Mockito.mock(Persister.class);
+        SharePartition sharePartition = 
Mockito.spy(SharePartitionBuilder.builder()
+            .withState(SharePartitionState.ACTIVE)
+            .withDefaultAcquisitionLockTimeoutMs(ACQUISITION_LOCK_TIMEOUT_MS)
+            .withPersister(persister)
+            .build());
+        
Mockito.doReturn(true).when(sharePartition).isRecordLimitMode(Mockito.any());
+        WriteShareGroupStateResult writeShareGroupStateResult = 
Mockito.mock(WriteShareGroupStateResult.class);
+        
Mockito.when(writeShareGroupStateResult.topicsData()).thenReturn(List.of(
+            new TopicData<>(TOPIC_ID_PARTITION.topicId(), List.of(
+                PartitionFactory.newPartitionErrorData(0, Errors.NONE.code(), 
Errors.NONE.message())))));
+        
Mockito.when(persister.writeState(Mockito.any())).thenReturn(CompletableFuture.completedFuture(writeShareGroupStateResult));
+
+        // Create 2 batches of records.
+        ByteBuffer buffer = ByteBuffer.allocate(4096);
+        memoryRecordsBuilder(buffer, 0, 5).close();
+        memoryRecordsBuilder(buffer, 5, 15).close();
+
+        buffer.flip();
+
+        MemoryRecords records = MemoryRecords.readableRecords(buffer);
+        // Acquire 2 records.
+        List<AcquiredRecords> acquiredRecordsList = 
fetchAcquiredRecords(sharePartition.acquire(
+            MEMBER_ID,
+            SHARE_ACQUIRE_MODE,
+            BATCH_SIZE,
+            2,
+            DEFAULT_FETCH_OFFSET,
+            fetchPartitionData(records, 10),
+            FETCH_ISOLATION_HWM),
+            2);
+
+        assertArrayEquals(expectedAcquiredRecord(0, 1, 1).toArray(), 
acquiredRecordsList.toArray());
+        assertThrows(IllegalStateException.class, () -> 
sharePartition.cachedState().get(0L).batchAcquisitionLockTimeoutTask());
+        // There should be 2 timer tasks for 2 offsets.

Review Comment:
   Done.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to