This is an automated email from the ASF dual-hosted git repository.

jsancio pushed a commit to branch 3.6
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.6 by this push:
     new 992166b9f37 KAFKA-15345; KRaft leader notifies leadership when listener reaches epoch start (#14213)
992166b9f37 is described below

commit 992166b9f37883bd678358c19bb877eba26be6c8
Author: José Armando García Sancio <jsan...@users.noreply.github.com>
AuthorDate: Thu Aug 17 18:40:17 2023 -0700

    KAFKA-15345; KRaft leader notifies leadership when listener reaches epoch start (#14213)
    
    In a non-empty log, the KRaft leader notifies the listener of leadership only once the listener has read up to the leader's epoch start offset. This guarantees that the leader epoch has been committed and that the listener has read all committed offsets/records.
    
    Unfortunately, the KRaft leader doesn't do this when the log is empty. When the log is empty, the listener is notified of leadership immediately, as soon as the replica becomes leader. This makes the API inconsistent and harder to program against.
    
    This change fixes that by having the KRaft leader wait for the listener's nextOffset to be greater than the leader's epochStartOffset before calling handleLeaderChange.
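    
    For illustration, here is a minimal sketch of a state machine written
    against the public RaftClient.Listener interface, showing what the
    stronger contract buys. The class and field names are hypothetical and
    not part of this commit:
    
        import java.util.OptionalInt;
        import org.apache.kafka.raft.Batch;
        import org.apache.kafka.raft.BatchReader;
        import org.apache.kafka.raft.LeaderAndEpoch;
        import org.apache.kafka.raft.RaftClient;
        import org.apache.kafka.snapshot.SnapshotReader;
    
        class CountingStateMachine implements RaftClient.Listener<String> {
            private final int localId;
            private volatile long lastCommittedOffset = -1L;
            private volatile OptionalInt claimedEpoch = OptionalInt.empty();
    
            CountingStateMachine(int localId) {
                this.localId = localId;
            }
    
            @Override
            public void handleCommit(BatchReader<String> reader) {
                try {
                    while (reader.hasNext()) {
                        Batch<String> batch = reader.next();
                        // Data and control batches both advance the offset.
                        lastCommittedOffset = batch.lastOffset();
                    }
                } finally {
                    reader.close();
                }
            }
    
            @Override
            public void handleSnapshot(SnapshotReader<String> reader) {
                reader.close(); // snapshot loading elided for brevity
            }
    
            @Override
            public void handleLeaderChange(LeaderAndEpoch leaderAndEpoch) {
                if (leaderAndEpoch.leaderId().equals(OptionalInt.of(localId))) {
                    // With this fix, this branch runs only after handleCommit
                    // has delivered every batch up to the epoch start offset,
                    // for an empty log as well as a non-empty one.
                    claimedEpoch = OptionalInt.of(leaderAndEpoch.epoch());
                }
            }
        }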
    
    The RecordsBatchReader implementation is also changed to include control records. This makes it possible for the state machine to learn about committed control records. This additional information can be used to compute the committed offset or to count those bytes when determining when to snapshot the partition.
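    
    A hedged sketch of the bookkeeping this enables, as a hypothetical
    helper class (it assumes the Batch.lastOffset() and Batch.sizeInBytes()
    accessors of the raft module's Batch type):
    
        import org.apache.kafka.raft.Batch;
    
        // Hypothetical bookkeeping helper, not part of this commit.
        final class SnapshotAccounting {
            private static final long SNAPSHOT_THRESHOLD_BYTES = 16 * 1024 * 1024;
    
            private long committedOffset = -1L;
            private long bytesSinceLastSnapshot = 0L;
    
            // Invoked from handleCommit for every committed batch. Because
            // control batches are now delivered too, the committed offset no
            // longer skips over LeaderChange records and the byte count
            // covers every committed byte in the partition.
            void onCommittedBatch(Batch<String> batch) {
                committedOffset = batch.lastOffset();
                bytesSinceLastSnapshot += batch.sizeInBytes();
            }
    
            boolean shouldSnapshot() {
                return bytesSinceLastSnapshot >= SNAPSHOT_THRESHOLD_BYTES;
            }
    
            void onSnapshotTaken() {
                bytesSinceLastSnapshot = 0L;
            }
        }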
    
    Reviewers: Colin P. McCabe <cmcc...@apache.org>, Jason Gustafson <ja...@confluent.io>
---
 .../apache/kafka/common/record/MemoryRecords.java  |  18 ++--
 .../org/apache/kafka/raft/KafkaRaftClient.java     |  11 +-
 .../kafka/raft/internals/RecordsBatchReader.java   |  10 +-
 .../org/apache/kafka/raft/KafkaRaftClientTest.java | 112 ++++++++++++++-------
 .../apache/kafka/raft/RaftClientTestContext.java   |  18 ++--
 .../raft/internals/RecordsBatchReaderTest.java     |  57 ++++++++---
 .../kafka/raft/internals/RecordsIteratorTest.java  |  70 +++++++++++++
 7 files changed, 217 insertions(+), 79 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
index fa18a88ca79..888bcdf2cb1 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
@@ -733,11 +733,13 @@ public class MemoryRecords extends AbstractRecords {
         return MemoryRecords.readableRecords(buffer);
     }
 
-    private static void writeLeaderChangeMessage(ByteBuffer buffer,
-                                                 long initialOffset,
-                                                 long timestamp,
-                                                 int leaderEpoch,
-                                                 LeaderChangeMessage leaderChangeMessage) {
+    private static void writeLeaderChangeMessage(
+        ByteBuffer buffer,
+        long initialOffset,
+        long timestamp,
+        int leaderEpoch,
+        LeaderChangeMessage leaderChangeMessage
+    ) {
         try (MemoryRecordsBuilder builder = new MemoryRecordsBuilder(
             buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE,
             TimestampType.CREATE_TIME, initialOffset, timestamp,
@@ -760,7 +762,8 @@ public class MemoryRecords extends AbstractRecords {
         return MemoryRecords.readableRecords(buffer);
     }
 
-    private static void writeSnapshotHeaderRecord(ByteBuffer buffer,
+    private static void writeSnapshotHeaderRecord(
+        ByteBuffer buffer,
         long initialOffset,
         long timestamp,
         int leaderEpoch,
@@ -788,7 +791,8 @@ public class MemoryRecords extends AbstractRecords {
         return MemoryRecords.readableRecords(buffer);
     }
 
-    private static void writeSnapshotFooterRecord(ByteBuffer buffer,
+    private static void writeSnapshotFooterRecord(
+        ByteBuffer buffer,
         long initialOffset,
         long timestamp,
         int leaderEpoch,
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
index 9f3e82a55b4..b2e14ee3ec9 100644
--- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
@@ -438,7 +438,6 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
         );
 
         LeaderState<T> state = quorum.transitionToLeader(endOffset, accumulator);
-        maybeFireLeaderChange(state);
 
         log.initializeLeaderEpoch(quorum.epoch());
 
@@ -2653,11 +2652,17 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
         }
 
         private void maybeFireLeaderChange(LeaderAndEpoch leaderAndEpoch, long epochStartOffset) {
-            // If this node is becoming the leader, then we can fire `handleClaim` as soon
+            // If this node is becoming the leader, then we can fire `handleLeaderChange` as soon
             // as the listener has caught up to the start of the leader epoch. This guarantees
             // that the state machine has seen the full committed state before it becomes
             // leader and begins writing to the log.
-            if (shouldFireLeaderChange(leaderAndEpoch) && nextOffset() >= epochStartOffset) {
+            //
+            // Note that the raft client doesn't need to compare nextOffset against the high-watermark
+            // to guarantee that the listener has caught up to the high-watermark. This is true because
+            // the only way nextOffset can be greater than epochStartOffset is for the leader to have
+            // established the new high-watermark (of at least epochStartOffset + 1) and for the listener
+            // to have consumed up to that new high-watermark.
+            if (shouldFireLeaderChange(leaderAndEpoch) && nextOffset() > epochStartOffset) {
                 lastFiredLeaderChange = leaderAndEpoch;
                 listener.handleLeaderChange(leaderAndEpoch);
             }
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/RecordsBatchReader.java b/raft/src/main/java/org/apache/kafka/raft/internals/RecordsBatchReader.java
index 61819a9dcca..84e2d9fe3d0 100644
--- a/raft/src/main/java/org/apache/kafka/raft/internals/RecordsBatchReader.java
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/RecordsBatchReader.java
@@ -117,14 +117,8 @@ public final class RecordsBatchReader<T> implements BatchReader<T> {
     }
 
     private Optional<Batch<T>> nextBatch() {
-        while (iterator.hasNext()) {
-            Batch<T> batch = iterator.next();
-
-            if (batch.records().isEmpty()) {
-                lastReturnedOffset = batch.lastOffset();
-            } else {
-                return Optional.of(batch);
-            }
+        if (iterator.hasNext()) {
+            return Optional.of(iterator.next());
         }
 
         return Optional.empty();
diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
index 068147c5fb3..d71874a5336 100644
--- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
@@ -1191,14 +1191,14 @@ public class KafkaRaftClientTest {
         List<String> records = Arrays.asList("a", "b", "c");
         long offset = context.client.scheduleAppend(epoch, records);
         context.client.poll();
-        assertEquals(OptionalLong.empty(), context.listener.lastCommitOffset());
+        assertEquals(OptionalLong.of(0L), context.listener.lastCommitOffset());
 
         // Let the follower send a fetch, it should advance the high watermark
         context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 1L, epoch, 500));
         context.pollUntilResponse();
         context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
         assertEquals(OptionalLong.of(1L), context.client.highWatermark());
-        assertEquals(OptionalLong.empty(), context.listener.lastCommitOffset());
+        assertEquals(OptionalLong.of(0L), context.listener.lastCommitOffset());
 
         // Let the follower send another fetch from offset 4
         context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 4L, epoch, 500));
@@ -2548,18 +2548,50 @@ public class KafkaRaftClientTest {
     }
 
     @Test
-    public void testHandleClaimFiresImmediatelyOnEmptyLog() throws Exception {
+    public void testHandleLeaderChangeFiresAfterListenerReachesEpochStartOffsetOnEmptyLog() throws Exception {
         int localId = 0;
         int otherNodeId = 1;
-        int epoch = 5;
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
 
-        RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch);
+        RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
+            .build();
+
+        context.becomeLeader();
+        context.client.poll();
+        int epoch = context.currentEpoch();
+
+        // After becoming leader, we expect the `LeaderChange` record to be appended.
+        assertEquals(1L, context.log.endOffset().offset);
+
+        // The high watermark is not known to the leader until the followers
+        // begin fetching, so we should not have fired the `handleLeaderChange` callback.
+        assertEquals(OptionalInt.empty(), context.listener.currentClaimedEpoch());
+        assertEquals(OptionalLong.empty(), context.listener.lastCommitOffset());
+
+        // Deliver a fetch from the other voter. The high watermark will not
+        // be exposed until it is able to reach the start of the leader epoch,
+        // so we are unable to deliver committed data or fire `handleLeaderChange`.
+        context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 0L, 0, 0));
+        context.client.poll();
+        assertEquals(OptionalInt.empty(), context.listener.currentClaimedEpoch());
+        assertEquals(OptionalLong.empty(), context.listener.lastCommitOffset());
+
+        // Now catch up to the start of the leader epoch so that the high
+        // watermark advances and we can start sending committed data to the
+        // listener. Note that the `LeaderChange` control record is included
+        // in the committed batches.
+        context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 1L, epoch, 0));
+        context.client.poll();
+        assertEquals(OptionalLong.of(0), context.listener.lastCommitOffset());
+
+        // Poll again now that the listener has caught up to the HWM
+        context.client.poll();
         assertEquals(OptionalInt.of(epoch), context.listener.currentClaimedEpoch());
+        assertEquals(0, context.listener.claimedEpochStartOffset(epoch));
     }
 
     @Test
-    public void testHandleClaimCallbackFiresAfterHighWatermarkReachesEpochStartOffset() throws Exception {
+    public void testHandleLeaderChangeFiresAfterListenerReachesEpochStartOffset() throws Exception {
         int localId = 0;
         int otherNodeId = 1;
         int epoch = 5;
@@ -2585,13 +2617,13 @@ public class KafkaRaftClientTest {
         assertEquals(10L, context.log.endOffset().offset);
 
         // The high watermark is not known to the leader until the followers
-        // begin fetching, so we should not have fired the `handleClaim` callback.
+        // begin fetching, so we should not have fired the `handleLeaderChange` callback.
         assertEquals(OptionalInt.empty(), context.listener.currentClaimedEpoch());
         assertEquals(OptionalLong.empty(), context.listener.lastCommitOffset());
 
         // Deliver a fetch from the other voter. The high watermark will not
         // be exposed until it is able to reach the start of the leader epoch,
-        // so we are unable to deliver committed data or fire `handleClaim`.
+        // so we are unable to deliver committed data or fire `handleLeaderChange`.
         context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 3L, 1, 500));
         context.client.poll();
         assertEquals(OptionalInt.empty(), context.listener.currentClaimedEpoch());
@@ -2599,26 +2631,28 @@ public class KafkaRaftClientTest {
 
         // Now catch up to the start of the leader epoch so that the high
         // watermark advances and we can start sending committed data to the
-        // listener. Note that the `LeaderChange` control record is filtered.
+        // listener. Note that the `LeaderChange` control record is included
+        // in the committed batches.
         context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 10L, epoch, 500));
         context.pollUntil(() -> {
-            int committedBatches = context.listener.numCommittedBatches();
-            long baseOffset = 0;
-            for (int index = 0; index < committedBatches; index++) {
-                List<String> expectedBatch = expectedBatches.get(index);
-                assertEquals(expectedBatch, context.listener.commitWithBaseOffset(baseOffset));
-                baseOffset += expectedBatch.size();
+            int index = 0;
+            for (Batch<String> batch : context.listener.committedBatches()) {
+                if (index < expectedBatches.size()) {
+                    // It must be a data record so compare it
+                    assertEquals(expectedBatches.get(index), batch.records());
+                }
+                index++;
             }
+            // The control record must be the last batch committed
+            assertEquals(4, index);
 
             return context.listener.currentClaimedEpoch().isPresent();
         });
 
         assertEquals(OptionalInt.of(epoch), context.listener.currentClaimedEpoch());
-        // Note that last committed offset is inclusive, hence we subtract 1.
-        assertEquals(
-            OptionalLong.of(expectedBatches.stream().mapToInt(List::size).sum() - 1),
-            context.listener.lastCommitOffset()
-        );
+        // Note that last committed offset is inclusive and must include the leader change record.
+        assertEquals(OptionalLong.of(9), context.listener.lastCommitOffset());
+        assertEquals(9, context.listener.claimedEpochStartOffset(epoch));
     }
 
     @Test
@@ -2647,18 +2681,18 @@ public class KafkaRaftClientTest {
         context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 10L, epoch, 0));
         context.pollUntil(() -> OptionalInt.of(epoch).equals(context.listener.currentClaimedEpoch()));
         assertEquals(OptionalLong.of(10L), context.client.highWatermark());
-        assertEquals(OptionalLong.of(8L), context.listener.lastCommitOffset());
+        assertEquals(OptionalLong.of(9L), context.listener.lastCommitOffset());
         assertEquals(OptionalInt.of(epoch), context.listener.currentClaimedEpoch());
-        // Ensure that the `handleClaim` callback was not fired early
+        // Ensure that the `handleLeaderChange` callback was not fired early
         assertEquals(9L, context.listener.claimedEpochStartOffset(epoch));
 
         // Register a second listener and allow it to catch up to the high watermark
         RaftClientTestContext.MockListener secondListener = new RaftClientTestContext.MockListener(OptionalInt.of(localId));
         context.client.register(secondListener);
         context.pollUntil(() -> OptionalInt.of(epoch).equals(secondListener.currentClaimedEpoch()));
-        assertEquals(OptionalLong.of(8L), secondListener.lastCommitOffset());
+        assertEquals(OptionalLong.of(9L), secondListener.lastCommitOffset());
         assertEquals(OptionalInt.of(epoch), context.listener.currentClaimedEpoch());
-        // Ensure that the `handleClaim` callback was not fired early
+        // Ensure that the `handleLeaderChange` callback was not fired early
         assertEquals(9L, secondListener.claimedEpochStartOffset(epoch));
     }
 
@@ -2686,12 +2720,12 @@ public class KafkaRaftClientTest {
 
         // Let the initial listener catch up
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
-        context.pollUntil(() -> OptionalLong.of(8).equals(context.listener.lastCommitOffset()));
+        context.pollUntil(() -> OptionalLong.of(9).equals(context.listener.lastCommitOffset()));
 
         // Register a second listener
         RaftClientTestContext.MockListener secondListener = new RaftClientTestContext.MockListener(OptionalInt.of(localId));
         context.client.register(secondListener);
-        context.pollUntil(() -> OptionalLong.of(8).equals(secondListener.lastCommitOffset()));
+        context.pollUntil(() -> OptionalLong.of(9).equals(secondListener.lastCommitOffset()));
         context.client.unregister(secondListener);
 
         // Write to the log and show that the default listener gets updated...
@@ -2700,7 +2734,7 @@ public class KafkaRaftClientTest {
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
         context.pollUntil(() -> OptionalLong.of(10).equals(context.listener.lastCommitOffset()));
         // ... but unregister listener doesn't
-        assertEquals(OptionalLong.of(8), secondListener.lastCommitOffset());
+        assertEquals(OptionalLong.of(9), secondListener.lastCommitOffset());
     }
 
     @Test
@@ -2785,14 +2819,18 @@ public class KafkaRaftClientTest {
         assertEquals(OptionalLong.of(10L), context.client.highWatermark());
 
         // Register another listener and verify that it catches up while we remain 'voted'
-        RaftClientTestContext.MockListener secondListener = new RaftClientTestContext.MockListener(OptionalInt.of(localId));
+        RaftClientTestContext.MockListener secondListener = new RaftClientTestContext.MockListener(
+            OptionalInt.of(localId)
+        );
         context.client.register(secondListener);
         context.client.poll();
         context.assertVotedCandidate(candidateEpoch, otherNodeId);
 
-        // Note the offset is 8 because the record at offset 9 is a control record
-        context.pollUntil(() -> secondListener.lastCommitOffset().equals(OptionalLong.of(8L)));
-        assertEquals(OptionalLong.of(8L), secondListener.lastCommitOffset());
+        // Note the offset is 9 because from offsets 0 to 8 there are data records,
+        // at offset 9 there is a control record, and the raft client sends the
+        // control record to the raft listener
+        context.pollUntil(() -> secondListener.lastCommitOffset().equals(OptionalLong.of(9L)));
+        assertEquals(OptionalLong.of(9L), secondListener.lastCommitOffset());
         assertEquals(OptionalInt.empty(), secondListener.currentClaimedEpoch());
     }
 
@@ -2835,14 +2873,18 @@ public class KafkaRaftClientTest {
         context.assertVotedCandidate(candidateEpoch, localId);
 
         // Register another listener and verify that it catches up
-        RaftClientTestContext.MockListener secondListener = new RaftClientTestContext.MockListener(OptionalInt.of(localId));
+        RaftClientTestContext.MockListener secondListener = new RaftClientTestContext.MockListener(
+            OptionalInt.of(localId)
+        );
         context.client.register(secondListener);
         context.client.poll();
         context.assertVotedCandidate(candidateEpoch, localId);
 
-        // Note the offset is 8 because the record at offset 9 is a control record
-        context.pollUntil(() -> secondListener.lastCommitOffset().equals(OptionalLong.of(8L)));
-        assertEquals(OptionalLong.of(8L), secondListener.lastCommitOffset());
+        // Note the offset is 9 because from offsets 0 to 8 there are data records,
+        // at offset 9 there is a control record, and the raft client sends the
+        // control record to the raft listener
+        context.pollUntil(() -> secondListener.lastCommitOffset().equals(OptionalLong.of(9L)));
+        assertEquals(OptionalLong.of(9L), secondListener.lastCommitOffset());
         assertEquals(OptionalInt.empty(), secondListener.currentClaimedEpoch());
     }
 
diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
index 9871e76133c..9d798453d0c 100644
--- a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
+++ b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
@@ -1116,6 +1116,10 @@ public final class RaftClientTestContext {
             return currentLeaderAndEpoch;
         }
 
+        List<Batch<String>> committedBatches() {
+            return commits;
+        }
+
         Batch<String> lastCommit() {
             if (commits.isEmpty()) {
                 return null;
@@ -1140,14 +1144,6 @@ public final class RaftClientTestContext {
             }
         }
 
-        List<String> commitWithBaseOffset(long baseOffset) {
-            return commits.stream()
-                .filter(batch -> batch.baseOffset() == baseOffset)
-                .findFirst()
-                .map(batch -> batch.records())
-                .orElse(null);
-        }
-
         List<String> commitWithLastOffset(long lastOffset) {
             return commits.stream()
                 .filter(batch -> batch.lastOffset() == lastOffset)
@@ -1194,14 +1190,14 @@ public final class RaftClientTestContext {
 
         @Override
         public void handleLeaderChange(LeaderAndEpoch leaderAndEpoch) {
-            // We record the next expected offset as the claimed epoch's start
+            // We record the current committed offset as the claimed epoch's start
             // offset. This is useful to verify that the `handleLeaderChange` callback
-            // was not received early.
+            // was not received early on the leader.
             this.currentLeaderAndEpoch = leaderAndEpoch;
 
             currentClaimedEpoch().ifPresent(claimedEpoch -> {
                 long claimedEpochStartOffset = lastCommitOffset().isPresent() ?
-                    lastCommitOffset().getAsLong() + 1 : 0L;
+                    lastCommitOffset().getAsLong() : 0L;
                 this.claimedEpochStartOffsets.put(leaderAndEpoch.epoch(), claimedEpochStartOffset);
             });
         }
diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/RecordsBatchReaderTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/RecordsBatchReaderTest.java
index ae8b1dfb8e2..762fa4177a2 100644
--- a/raft/src/test/java/org/apache/kafka/raft/internals/RecordsBatchReaderTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/internals/RecordsBatchReaderTest.java
@@ -16,13 +16,17 @@
  */
 package org.apache.kafka.raft.internals;
 
+import org.apache.kafka.common.message.LeaderChangeMessage;
 import org.apache.kafka.common.record.CompressionType;
+import org.apache.kafka.common.record.ControlRecordType;
 import org.apache.kafka.common.record.FileRecords;
 import org.apache.kafka.common.record.MemoryRecords;
 import org.apache.kafka.common.record.Records;
 import org.apache.kafka.common.utils.BufferSupplier;
 import org.apache.kafka.raft.BatchReader;
+import org.apache.kafka.raft.ControlRecord;
 import org.apache.kafka.raft.internals.RecordsIteratorTest.TestBatch;
+import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.EnumSource;
 import org.mockito.Mockito;
@@ -48,9 +52,9 @@ class RecordsBatchReaderTest {
     @ParameterizedTest
     @EnumSource(CompressionType.class)
     public void testReadFromMemoryRecords(CompressionType compressionType) {
-        long baseOffset = 57;
-
-        List<TestBatch<String>> batches = RecordsIteratorTest.createBatches(baseOffset);
+        long seed = 57;
+        List<TestBatch<String>> batches = RecordsIteratorTest.createBatches(seed);
+        long baseOffset = batches.get(0).baseOffset;
         MemoryRecords memRecords = RecordsIteratorTest.buildRecords(compressionType, batches);
 
         testBatchReader(baseOffset, memRecords, batches);
@@ -59,9 +63,9 @@ class RecordsBatchReaderTest {
     @ParameterizedTest
     @EnumSource(CompressionType.class)
     public void testReadFromFileRecords(CompressionType compressionType) throws Exception {
-        long baseOffset = 57;
-
-        List<TestBatch<String>> batches = RecordsIteratorTest.createBatches(baseOffset);
+        long seed = 57;
+        List<TestBatch<String>> batches = RecordsIteratorTest.createBatches(seed);
+        long baseOffset = batches.get(0).baseOffset;
         MemoryRecords memRecords = RecordsIteratorTest.buildRecords(compressionType, batches);
 
         FileRecords fileRecords = FileRecords.open(tempFile());
@@ -70,6 +74,28 @@ class RecordsBatchReaderTest {
         testBatchReader(baseOffset, fileRecords, batches);
     }
 
+    @Test
+    public void testLeaderChangeControlBatch() {
+        // Confirm that the RecordsBatchReader is able to iterate over control batches
+        MemoryRecords records = RecordsIteratorTest.buildControlRecords(ControlRecordType.LEADER_CHANGE);
+        ControlRecord expectedRecord = new ControlRecord(ControlRecordType.LEADER_CHANGE, new LeaderChangeMessage());
+
+        try (RecordsBatchReader<String> reader = RecordsBatchReader.of(
+                0,
+                records,
+                serde,
+                BufferSupplier.NO_CACHING,
+                MAX_BATCH_BYTES,
+                ignore -> { },
+                true
+            )
+        ) {
+            assertTrue(reader.hasNext());
+            assertEquals(Collections.singletonList(expectedRecord), reader.next().controlRecords());
+            assertFalse(reader.hasNext());
+        }
+    }
+
     private void testBatchReader(
         long baseOffset,
         Records records,
@@ -103,18 +129,19 @@ class RecordsBatchReaderTest {
             closeListener,
             true
         );
-
-        for (TestBatch<String> batch : expectedBatches) {
-            assertTrue(reader.hasNext());
-            assertEquals(batch, TestBatch.from(reader.next()));
+        try {
+            for (TestBatch<String> batch : expectedBatches) {
+                assertTrue(reader.hasNext());
+                assertEquals(batch, TestBatch.from(reader.next()));
+            }
+
+            assertFalse(reader.hasNext());
+            assertThrows(NoSuchElementException.class, reader::next);
+        } finally {
+            reader.close();
         }
 
-        assertFalse(reader.hasNext());
-        assertThrows(NoSuchElementException.class, reader::next);
-
-        reader.close();
         Mockito.verify(closeListener).onClose(reader);
         assertEquals(Collections.emptySet(), allocatedBuffers);
     }
-
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java
index 67f16c9ac8f..9433dbe1a96 100644
--- a/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java
@@ -34,8 +34,10 @@ import net.jqwik.api.ForAll;
 import net.jqwik.api.Property;
 import org.apache.kafka.common.errors.CorruptRecordException;
 import org.apache.kafka.common.memory.MemoryPool;
+import org.apache.kafka.common.message.LeaderChangeMessage;
 import org.apache.kafka.common.message.SnapshotFooterRecord;
 import org.apache.kafka.common.message.SnapshotHeaderRecord;
+import org.apache.kafka.common.protocol.ApiMessage;
 import org.apache.kafka.common.record.CompressionType;
 import org.apache.kafka.common.record.ControlRecordType;
 import org.apache.kafka.common.record.DefaultRecordBatch;
@@ -45,6 +47,7 @@ import org.apache.kafka.common.record.Records;
 import org.apache.kafka.common.utils.BufferSupplier;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.raft.Batch;
+import org.apache.kafka.raft.ControlRecord;
 import org.apache.kafka.raft.OffsetAndEpoch;
 import org.apache.kafka.server.common.serialization.RecordSerde;
 import org.apache.kafka.snapshot.MockRawSnapshotWriter;
@@ -53,6 +56,7 @@ import org.apache.kafka.test.TestUtils;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.EnumSource;
 import org.junit.jupiter.params.provider.MethodSource;
 import org.mockito.Mockito;
 import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
@@ -188,6 +192,35 @@ public final class RecordsIteratorTest {
         }
     }
 
+    @ParameterizedTest
+    @EnumSource(value = ControlRecordType.class, names = {"LEADER_CHANGE", "SNAPSHOT_HEADER", "SNAPSHOT_FOOTER"})
+    void testWithAllSupportedControlRecords(ControlRecordType type) {
+        MemoryRecords records = buildControlRecords(type);
+        final ApiMessage expectedMessage;
+        switch (type) {
+            case LEADER_CHANGE:
+                expectedMessage = new LeaderChangeMessage();
+                break;
+            case SNAPSHOT_HEADER:
+                expectedMessage = new SnapshotHeaderRecord();
+                break;
+            case SNAPSHOT_FOOTER:
+                expectedMessage = new SnapshotFooterRecord();
+                break;
+            default:
+                throw new RuntimeException("Should not happen. Poorly configured test");
+        }
+
+        try (RecordsIterator<String> iterator = createIterator(records, BufferSupplier.NO_CACHING, true)) {
+            assertTrue(iterator.hasNext());
+            assertEquals(
+                Collections.singletonList(new ControlRecord(type, expectedMessage)),
+                iterator.next().controlRecords()
+            );
+            assertFalse(iterator.hasNext());
+        }
+    }
+
     @Test
     void testControlRecordTypeValues() {
         // If this test fails then it means that ControlRecordType was changed. Please review the
@@ -274,6 +307,43 @@ public final class RecordsIteratorTest {
         return batches;
     }
 
+    public static MemoryRecords buildControlRecords(ControlRecordType type) {
+        final MemoryRecords records;
+        switch (type) {
+            case LEADER_CHANGE:
+                records = MemoryRecords.withLeaderChangeMessage(
+                    0,
+                    0,
+                    1,
+                    ByteBuffer.allocate(128),
+                    new LeaderChangeMessage()
+                );
+                break;
+            case SNAPSHOT_HEADER:
+                records = MemoryRecords.withSnapshotHeaderRecord(
+                    0,
+                    0,
+                    1,
+                    ByteBuffer.allocate(128),
+                    new SnapshotHeaderRecord()
+                );
+                break;
+            case SNAPSHOT_FOOTER:
+                records = MemoryRecords.withSnapshotFooterRecord(
+                    0,
+                    0,
+                    1,
+                    ByteBuffer.allocate(128),
+                    new SnapshotFooterRecord()
+                );
+                break;
+            default:
+                throw new RuntimeException(String.format("Control record type %s is not supported", type));
+        }
+
+        return records;
+    }
+
     public static MemoryRecords buildRecords(
         CompressionType compressionType,
         List<TestBatch<String>> batches
