This is an automated email from the ASF dual-hosted git repository.

schofielaj pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 05b2601dde1 KAFKA-19456: State and leader epoch should not be updated on writes. (#20079)
05b2601dde1 is described below

commit 05b2601dde14d52da3a14da827dd654f15934e93
Author: Sushant Mahajan <[email protected]>
AuthorDate: Wed Jul 2 00:27:57 2025 +0530

    KAFKA-19456: State and leader epoch should not be updated on writes. (#20079)
    
    * If a write request with a higher state epoch than seen so far for a
    specific share partition arrives at the share coordinator, the code
    creates a new share snapshot and also updates the internal view of
    the state epoch.
    * For writes with a higher leader epoch, the current records are
    updated with that value as well.
    * This is not the expected behavior: only the initialize RPC should
    set or alter the state epoch, and only the read RPC should set the
    leader epoch.
    * This PR rectifies the behavior (see the sketch after the diffstat
    below).
    * A few tests have been removed.
    
    Reviewers: Andrew Schofield <[email protected]>
---
 .../coordinator/share/ShareCoordinatorShard.java   |  31 +--
 .../share/ShareCoordinatorShardTest.java           | 241 ++++-----------------
 2 files changed, 57 insertions(+), 215 deletions(-)
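
For illustration only: a minimal, standalone Java sketch of the epoch handling
described above, assuming the updateLeaderEpoch flag visible in the diff below.
The class and method names here are hypothetical and are not part of the Kafka
codebase.

    // Hypothetical model of the intended epoch handling: writes never move
    // the leader or state epoch; only the read path may advance the leader
    // epoch, and only the initialize path sets the state epoch.
    final class EpochResolutionSketch {

        // Writes call this with updateLeaderEpoch=false, so the stored value
        // is kept; the read path calls it with true and may advance it.
        static int resolveLeaderEpoch(int currentLeaderEpoch,
                                      int requestLeaderEpoch,
                                      boolean updateLeaderEpoch) {
            if (updateLeaderEpoch && requestLeaderEpoch != -1) {
                return requestLeaderEpoch;
            }
            return currentLeaderEpoch;
        }

        // The state epoch is no longer taken from a write request; snapshots
        // simply carry the coordinator's current value forward.
        static int resolveStateEpoch(int currentStateEpoch) {
            return currentStateEpoch;
        }

        public static void main(String[] args) {
            System.out.println(resolveLeaderEpoch(2, 5, false)); // 2: write keeps the current epoch
            System.out.println(resolveLeaderEpoch(2, 5, true));  // 5: read may advance it
        }
    }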

diff --git a/share-coordinator/src/main/java/org/apache/kafka/coordinator/share/ShareCoordinatorShard.java b/share-coordinator/src/main/java/org/apache/kafka/coordinator/share/ShareCoordinatorShard.java
index 86b2d506376..ba47582fe16 100644
--- a/share-coordinator/src/main/java/org/apache/kafka/coordinator/share/ShareCoordinatorShard.java
+++ b/share-coordinator/src/main/java/org/apache/kafka/coordinator/share/ShareCoordinatorShard.java
@@ -328,7 +328,7 @@ public class ShareCoordinatorShard implements CoordinatorShard<CoordinatorRecord
         WriteShareGroupStateRequestData.PartitionData partitionData = topicData.partitions().get(0);
         SharePartitionKey key = SharePartitionKey.getInstance(request.groupId(), topicData.topicId(), partitionData.partition());
 
-        CoordinatorRecord record = generateShareStateRecord(partitionData, key);
+        CoordinatorRecord record = generateShareStateRecord(partitionData, key, false);
         // build successful response if record is correctly created
         WriteShareGroupStateResponseData responseData = new WriteShareGroupStateResponseData().setResults(
             List.of(WriteShareGroupStateResponse.toResponseWriteStateResult(key.topicId(),
@@ -405,7 +405,7 @@ public class ShareCoordinatorShard implements CoordinatorShard<CoordinatorRecord
             .setStartOffset(responseData.results().get(0).partitions().get(0).startOffset())
             .setStateEpoch(responseData.results().get(0).partitions().get(0).stateEpoch());
 
-        CoordinatorRecord record = generateShareStateRecord(writePartitionData, key);
+        CoordinatorRecord record = generateShareStateRecord(writePartitionData, key, true);
         return new CoordinatorResult<>(List.of(record), responseData);
     }
 
@@ -637,20 +637,27 @@ public class ShareCoordinatorShard implements CoordinatorShard<CoordinatorRecord
      * seen so far => create a new ShareSnapshot record else create a new ShareUpdate record. This method assumes
      * that share partition key is present in shareStateMap since it should be called on initialized share partitions.
      *
-     * @param partitionData - Represents the data which should be written into the share state record.
-     * @param key           - The {@link SharePartitionKey} object.
+     * @param partitionData     - Represents the data which should be written into the share state record.
+     * @param key               - The {@link SharePartitionKey} object.
+     * @param updateLeaderEpoch - Should the leader epoch be updated, if higher.
      * @return {@link CoordinatorRecord} representing ShareSnapshot or ShareUpdate
      */
     private CoordinatorRecord generateShareStateRecord(
         WriteShareGroupStateRequestData.PartitionData partitionData,
-        SharePartitionKey key
+        SharePartitionKey key,
+        boolean updateLeaderEpoch
     ) {
         long timestamp = time.milliseconds();
         int updatesPerSnapshotLimit = config.shareCoordinatorSnapshotUpdateRecordsPerSnapshot();
-        if (snapshotUpdateCount.getOrDefault(key, 0) >= updatesPerSnapshotLimit || partitionData.stateEpoch() > shareStateMap.get(key).stateEpoch()) {
-            ShareGroupOffset currentState = shareStateMap.get(key); // shareStateMap will have the entry as containsKey is true
-            int newLeaderEpoch = partitionData.leaderEpoch() == -1 ? currentState.leaderEpoch() : partitionData.leaderEpoch();
-            int newStateEpoch = partitionData.stateEpoch() == -1 ? currentState.stateEpoch() : partitionData.stateEpoch();
+        ShareGroupOffset currentState = shareStateMap.get(key); // This method assumes containsKey is true.
+
+        int newLeaderEpoch = currentState.leaderEpoch();
+        if (updateLeaderEpoch) {
+            newLeaderEpoch = partitionData.leaderEpoch() != -1 ? partitionData.leaderEpoch() : newLeaderEpoch;
+        }
+
+        if (snapshotUpdateCount.getOrDefault(key, 0) >= updatesPerSnapshotLimit) {
+            // shareStateMap will have the entry as containsKey is true
             long newStartOffset = partitionData.startOffset() == -1 ? currentState.startOffset() : partitionData.startOffset();
 
             // Since the number of update records for this share part key exceeds snapshotUpdateRecordsPerSnapshot
@@ -662,14 +669,12 @@ public class ShareCoordinatorShard implements CoordinatorShard<CoordinatorRecord
                     .setSnapshotEpoch(currentState.snapshotEpoch() + 1)   // We must increment snapshot epoch as this is new snapshot.
                     .setStartOffset(newStartOffset)
                     .setLeaderEpoch(newLeaderEpoch)
-                    .setStateEpoch(newStateEpoch)
+                    .setStateEpoch(currentState.stateEpoch())
                     .setStateBatches(mergeBatches(currentState.stateBatches(), partitionData, newStartOffset))
                     .setCreateTimestamp(timestamp)
                     .setWriteTimestamp(timestamp)
                     .build());
         } else {
-            ShareGroupOffset currentState = shareStateMap.get(key); // shareStateMap will have the entry as containsKey is true.
-
             // Share snapshot is present and number of share snapshot update records < snapshotUpdateRecordsPerSnapshot
             // so create a share update record.
             // The incoming partition data could have overlapping state batches, we must merge them.
@@ -678,7 +683,7 @@ public class ShareCoordinatorShard implements CoordinatorShard<CoordinatorRecord
                 new ShareGroupOffset.Builder()
                     .setSnapshotEpoch(currentState.snapshotEpoch()) // Use same snapshotEpoch as last share snapshot.
                     .setStartOffset(partitionData.startOffset())
-                    .setLeaderEpoch(partitionData.leaderEpoch())
+                    .setLeaderEpoch(newLeaderEpoch)
                     .setStateBatches(mergeBatches(List.of(), partitionData))
                     .build());
         }
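
A hedged note on the hunk above: after this change, whether a write produces a
new ShareSnapshot or a ShareUpdate depends only on the number of accumulated
update records for the key, compared against
shareCoordinatorSnapshotUpdateRecordsPerSnapshot; a higher state epoch in the
write request no longer forces a snapshot. A standalone sketch follows
(hypothetical names, not Kafka code):

    // Hypothetical model of the snapshot-vs-update decision after this change:
    // only the accumulated update count matters.
    final class SnapshotDecisionSketch {

        static boolean shouldCreateSnapshot(int updatesSinceLastSnapshot,
                                            int updatesPerSnapshotLimit) {
            return updatesSinceLastSnapshot >= updatesPerSnapshotLimit;
        }

        public static void main(String[] args) {
            int limit = 50; // example value for shareCoordinatorSnapshotUpdateRecordsPerSnapshot
            System.out.println(shouldCreateSnapshot(10, limit)); // false -> ShareUpdate record
            System.out.println(shouldCreateSnapshot(50, limit)); // true  -> new ShareSnapshot record
        }
    }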
diff --git a/share-coordinator/src/test/java/org/apache/kafka/coordinator/share/ShareCoordinatorShardTest.java b/share-coordinator/src/test/java/org/apache/kafka/coordinator/share/ShareCoordinatorShardTest.java
index 9aed6583f1d..ddd0a566d6b 100644
--- a/share-coordinator/src/test/java/org/apache/kafka/coordinator/share/ShareCoordinatorShardTest.java
+++ b/share-coordinator/src/test/java/org/apache/kafka/coordinator/share/ShareCoordinatorShardTest.java
@@ -68,6 +68,7 @@ import java.util.Set;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -143,6 +144,21 @@ class ShareCoordinatorShardTest {
     }
 
     private void writeAndReplayRecord(ShareCoordinatorShard shard, int leaderEpoch) {
+        // Read is necessary before write!
+        ReadShareGroupStateRequestData readRequest = new ReadShareGroupStateRequestData()
+            .setGroupId(GROUP_ID)
+            .setTopics(List.of(new ReadShareGroupStateRequestData.ReadStateData()
+                .setTopicId(TOPIC_ID)
+                .setPartitions(List.of(new ReadShareGroupStateRequestData.PartitionData()
+                    .setPartition(PARTITION)
+                    .setLeaderEpoch(leaderEpoch)
+                ))
+            ));
+
+        CoordinatorResult<ReadShareGroupStateResponseData, CoordinatorRecord> result = shard.readStateAndMaybeUpdateLeaderEpoch(readRequest);
+
+        shard.replay(0L, 0L, (short) 0, result.records().get(0));
+
         WriteShareGroupStateRequestData request = new WriteShareGroupStateRequestData()
             .setGroupId(GROUP_ID)
             .setTopics(List.of(new WriteShareGroupStateRequestData.WriteStateData()
@@ -158,9 +174,9 @@ class ShareCoordinatorShardTest {
                         .setDeliveryCount((short) 1)
                         .setDeliveryState((byte) 0)))))));
 
-        CoordinatorResult<WriteShareGroupStateResponseData, CoordinatorRecord> result = shard.writeState(request);
+        CoordinatorResult<WriteShareGroupStateResponseData, CoordinatorRecord> result2 = shard.writeState(request);
 
-        shard.replay(0L, 0L, (short) 0, result.records().get(0));
+        shard.replay(0L, 0L, (short) 0, result2.records().get(0));
     }
 
     private ShareCoordinatorShard shard;
@@ -269,6 +285,7 @@ class ShareCoordinatorShardTest {
     @Test
     public void testWriteStateSuccess() {
         initSharePartition(shard, SHARE_PARTITION_KEY);
+        writeAndReplayRecord(shard, 0);
 
         WriteShareGroupStateRequestData request = new 
WriteShareGroupStateRequestData()
             .setGroupId(GROUP_ID)
@@ -301,133 +318,13 @@ class ShareCoordinatorShardTest {
             GROUP_ID, TOPIC_ID, PARTITION, 
ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 
TIME.milliseconds())
         ).value().message()), 
shard.getShareStateMapValue(SHARE_PARTITION_KEY));
         assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY));
-        
verify(shard.getMetricsShard()).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME);
-    }
-
-    @Test
-    public void 
testWriteStateSequentialRequestsWithHigherStateEpochCreateShareSnapshots() {
-        // Makes 3 requests. First 2 with same state epoch, and 3rd with 
incremented state epoch.
-        // The test config defines number of updates/snapshot as 50. So, this 
test proves that
-        // a higher state epoch in a request forces snapshot creation, even if 
number of share updates
-        // have not breached the updates/snapshots limit.
-
-        int stateEpoch = 1;
-        int snapshotEpoch = 0;
-
-        initSharePartition(shard, SHARE_PARTITION_KEY);
-
-        WriteShareGroupStateRequestData request = new 
WriteShareGroupStateRequestData()
-            .setGroupId(GROUP_ID)
-            .setTopics(List.of(new 
WriteShareGroupStateRequestData.WriteStateData()
-                .setTopicId(TOPIC_ID)
-                .setPartitions(List.of(new 
WriteShareGroupStateRequestData.PartitionData()
-                    .setPartition(PARTITION)
-                    .setStartOffset(0)
-                    .setStateEpoch(stateEpoch)
-                    .setLeaderEpoch(0)
-                    .setStateBatches(List.of(new 
WriteShareGroupStateRequestData.StateBatch()
-                        .setFirstOffset(0)
-                        .setLastOffset(10)
-                        .setDeliveryCount((short) 1)
-                        .setDeliveryState((byte) 0)))))));
-
-        CoordinatorResult<WriteShareGroupStateResponseData, CoordinatorRecord> 
result = shard.writeState(request);
-
-        shard.replay(0L, 0L, (short) 0, result.records().get(0));
-
-        snapshotEpoch++;    // Since state epoch increased.
-        WriteShareGroupStateResponseData expectedData = 
WriteShareGroupStateResponse.toResponseData(TOPIC_ID, PARTITION);
-        List<CoordinatorRecord> expectedRecords = 
List.of(ShareCoordinatorRecordHelpers.newShareSnapshotRecord(
-            GROUP_ID, TOPIC_ID, PARTITION, 
ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 
snapshotEpoch, TIME.milliseconds())
-        ));
-
-        assertEquals(1, 
shard.getShareStateMapValue(SHARE_PARTITION_KEY).snapshotEpoch());
-        assertEquals(expectedData, result.response());
-        assertEquals(expectedRecords, result.records());
-
-        
assertEquals(groupOffset(ShareCoordinatorRecordHelpers.newShareSnapshotRecord(
-            GROUP_ID, TOPIC_ID, PARTITION, 
ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 
snapshotEpoch, TIME.milliseconds())
-        ).value().message()), 
shard.getShareStateMapValue(SHARE_PARTITION_KEY));
-        assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY));
-        
verify(shard.getMetricsShard()).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME);
-
-        // State epoch stays same so share update.
-        request = new WriteShareGroupStateRequestData()
-            .setGroupId(GROUP_ID)
-            .setTopics(List.of(new 
WriteShareGroupStateRequestData.WriteStateData()
-                .setTopicId(TOPIC_ID)
-                .setPartitions(List.of(new 
WriteShareGroupStateRequestData.PartitionData()
-                    .setPartition(PARTITION)
-                    .setStartOffset(0)
-                    .setStateEpoch(stateEpoch)
-                    .setLeaderEpoch(0)
-                    .setStateBatches(List.of(new 
WriteShareGroupStateRequestData.StateBatch()
-                        .setFirstOffset(0)
-                        .setLastOffset(10)
-                        .setDeliveryCount((short) 2)
-                        .setDeliveryState((byte) 0)))))));
-
-        result = shard.writeState(request);
-
-        shard.replay(0L, 0L, (short) 0, result.records().get(0));
-
-        expectedData = WriteShareGroupStateResponse.toResponseData(TOPIC_ID, 
PARTITION);
-        expectedRecords = 
List.of(ShareCoordinatorRecordHelpers.newShareUpdateRecord(
-            GROUP_ID, TOPIC_ID, PARTITION, 
ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 
snapshotEpoch, TIME.milliseconds())
-        ));
-
-        // Snapshot epoch did not increase
-        assertEquals(1, 
shard.getShareStateMapValue(SHARE_PARTITION_KEY).snapshotEpoch());
-        assertEquals(expectedData, result.response());
-        assertEquals(expectedRecords, result.records());
-
-        
assertEquals(groupOffset(ShareCoordinatorRecordHelpers.newShareSnapshotRecord(
-            GROUP_ID, TOPIC_ID, PARTITION, 
ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 
snapshotEpoch, TIME.milliseconds())
-        ).value().message()), 
shard.getShareStateMapValue(SHARE_PARTITION_KEY));
-        assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY));
-        verify(shard.getMetricsShard(), 
times(2)).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME);
-
-        // State epoch incremented so share snapshot.
-        request = new WriteShareGroupStateRequestData()
-            .setGroupId(GROUP_ID)
-            .setTopics(List.of(new 
WriteShareGroupStateRequestData.WriteStateData()
-                .setTopicId(TOPIC_ID)
-                .setPartitions(List.of(new 
WriteShareGroupStateRequestData.PartitionData()
-                    .setPartition(PARTITION)
-                    .setStartOffset(0)
-                    .setStateEpoch(stateEpoch + 1)   // incremented
-                    .setLeaderEpoch(0)
-                    .setStateBatches(List.of(new 
WriteShareGroupStateRequestData.StateBatch()
-                        .setFirstOffset(0)
-                        .setLastOffset(10)
-                        .setDeliveryCount((short) 2)
-                        .setDeliveryState((byte) 0)))))));
-
-        result = shard.writeState(request);
-
-        shard.replay(0L, 0L, (short) 0, result.records().get(0));
-
-        snapshotEpoch++;    // Since state epoch increased
-        expectedData = WriteShareGroupStateResponse.toResponseData(TOPIC_ID, 
PARTITION);
-        expectedRecords = 
List.of(ShareCoordinatorRecordHelpers.newShareSnapshotRecord(
-            GROUP_ID, TOPIC_ID, PARTITION, 
ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 
snapshotEpoch, TIME.milliseconds())
-        ));
-
-        // Snapshot epoch increased.
-        assertEquals(2, 
shard.getShareStateMapValue(SHARE_PARTITION_KEY).snapshotEpoch());
-        assertEquals(expectedData, result.response());
-        assertEquals(expectedRecords, result.records());
-
-        
assertEquals(groupOffset(ShareCoordinatorRecordHelpers.newShareSnapshotRecord(
-            GROUP_ID, TOPIC_ID, PARTITION, 
ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 
snapshotEpoch, TIME.milliseconds())
-        ).value().message()), 
shard.getShareStateMapValue(SHARE_PARTITION_KEY));
-        assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY));
         verify(shard.getMetricsShard(), 
times(3)).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME);
     }
 
     @Test
     public void testSubsequentWriteStateSnapshotEpochUpdatesSuccessfully() {
         initSharePartition(shard, SHARE_PARTITION_KEY);
+        writeAndReplayRecord(shard, 0);
 
         WriteShareGroupStateRequestData request1 = new 
WriteShareGroupStateRequestData()
             .setGroupId(GROUP_ID)
@@ -567,6 +464,7 @@ class ShareCoordinatorShardTest {
     @Test
     public void testWriteStateFencedLeaderEpochError() {
         initSharePartition(shard, SHARE_PARTITION_KEY);
+        writeAndReplayRecord(shard, 1);
 
         WriteShareGroupStateRequestData request1 = new 
WriteShareGroupStateRequestData()
             .setGroupId(GROUP_ID)
@@ -576,78 +474,29 @@ class ShareCoordinatorShardTest {
                     .setPartition(PARTITION)
                     .setStartOffset(0)
                     .setStateEpoch(0)
-                    .setLeaderEpoch(5)
+                    .setLeaderEpoch(0)
                     .setStateBatches(List.of(new 
WriteShareGroupStateRequestData.StateBatch()
                         .setFirstOffset(0)
                         .setLastOffset(10)
                         .setDeliveryCount((short) 1)
                         .setDeliveryState((byte) 0)))))));
 
-        WriteShareGroupStateRequestData request2 = new 
WriteShareGroupStateRequestData()
-            .setGroupId(GROUP_ID)
-            .setTopics(List.of(new 
WriteShareGroupStateRequestData.WriteStateData()
-                .setTopicId(TOPIC_ID)
-                .setPartitions(List.of(new 
WriteShareGroupStateRequestData.PartitionData()
-                    .setPartition(PARTITION)
-                    .setStartOffset(0)
-                    .setStateEpoch(0)
-                    .setLeaderEpoch(3) // Lower leader epoch in the second 
request.
-                    .setStateBatches(List.of(new 
WriteShareGroupStateRequestData.StateBatch()
-                        .setFirstOffset(11)
-                        .setLastOffset(20)
-                        .setDeliveryCount((short) 1)
-                        .setDeliveryState((byte) 0)))))));
-
         CoordinatorResult<WriteShareGroupStateResponseData, CoordinatorRecord> 
result = shard.writeState(request1);
 
-        shard.replay(0L, 0L, (short) 0, result.records().get(0));
-
-        WriteShareGroupStateResponseData expectedData = 
WriteShareGroupStateResponse.toResponseData(TOPIC_ID, PARTITION);
-        List<CoordinatorRecord> expectedRecords = 
List.of(ShareCoordinatorRecordHelpers.newShareUpdateRecord(
-            GROUP_ID, TOPIC_ID, PARTITION, 
ShareGroupOffset.fromRequest(request1.topics().get(0).partitions().get(0), 
TIME.milliseconds())
-        ));
-
-        assertEquals(expectedData, result.response());
-        assertEquals(expectedRecords, result.records());
-
-        assertEquals(groupOffset(expectedRecords.get(0).value().message()),
-            shard.getShareStateMapValue(SHARE_PARTITION_KEY));
-        assertEquals(5, shard.getLeaderMapValue(SHARE_PARTITION_KEY));
-
-        result = shard.writeState(request2);
-
-        // Since the leader epoch in the second request was lower than the one 
in the first request, FENCED_LEADER_EPOCH error is expected.
-        expectedData = WriteShareGroupStateResponse.toErrorResponseData(
+        WriteShareGroupStateResponseData expectedData = 
WriteShareGroupStateResponse.toErrorResponseData(
             TOPIC_ID, PARTITION, Errors.FENCED_LEADER_EPOCH, 
Errors.FENCED_LEADER_EPOCH.message());
-        expectedRecords = List.of();
+        List<CoordinatorRecord> expectedRecords = List.of();
 
         assertEquals(expectedData, result.response());
         assertEquals(expectedRecords, result.records());
-
-        // No changes to the leaderMap.
-        assertEquals(5, shard.getLeaderMapValue(SHARE_PARTITION_KEY));
+        assertEquals(1, shard.getLeaderMapValue(SHARE_PARTITION_KEY));
     }
 
     @Test
     public void testWriteStateFencedStateEpochError() {
-        initSharePartition(shard, SHARE_PARTITION_KEY);
+        initSharePartition(shard, SHARE_PARTITION_KEY, 1);
 
         WriteShareGroupStateRequestData request1 = new 
WriteShareGroupStateRequestData()
-            .setGroupId(GROUP_ID)
-            .setTopics(List.of(new 
WriteShareGroupStateRequestData.WriteStateData()
-                .setTopicId(TOPIC_ID)
-                .setPartitions(List.of(new 
WriteShareGroupStateRequestData.PartitionData()
-                    .setPartition(PARTITION)
-                    .setStartOffset(0)
-                    .setStateEpoch(1)
-                    .setLeaderEpoch(5)
-                    .setStateBatches(List.of(new 
WriteShareGroupStateRequestData.StateBatch()
-                        .setFirstOffset(0)
-                        .setLastOffset(10)
-                        .setDeliveryCount((short) 1)
-                        .setDeliveryState((byte) 0)))))));
-
-        WriteShareGroupStateRequestData request2 = new 
WriteShareGroupStateRequestData()
             .setGroupId(GROUP_ID)
             .setTopics(List.of(new 
WriteShareGroupStateRequestData.WriteStateData()
                 .setTopicId(TOPIC_ID)
@@ -664,29 +513,13 @@ class ShareCoordinatorShardTest {
 
         CoordinatorResult<WriteShareGroupStateResponseData, CoordinatorRecord> 
result = shard.writeState(request1);
 
-        shard.replay(0L, 0L, (short) 0, result.records().get(0));
-
-        WriteShareGroupStateResponseData expectedData = 
WriteShareGroupStateResponse.toResponseData(TOPIC_ID, PARTITION);
-        List<CoordinatorRecord> expectedRecords = 
List.of(ShareCoordinatorRecordHelpers.newShareSnapshotRecord(
-            GROUP_ID, TOPIC_ID, PARTITION, 
ShareGroupOffset.fromRequest(request1.topics().get(0).partitions().get(0), 1, 
TIME.milliseconds())
-        ));
-
-        assertEquals(expectedData, result.response());
-        assertEquals(expectedRecords, result.records());
-
-        assertEquals(groupOffset(expectedRecords.get(0).value().message()),
-            shard.getShareStateMapValue(SHARE_PARTITION_KEY));
-        assertEquals(5, shard.getLeaderMapValue(SHARE_PARTITION_KEY));
-
-        result = shard.writeState(request2);
-
-        // Since the leader epoch in the second request was lower than the one 
in the first request, FENCED_LEADER_EPOCH error is expected.
-        expectedData = WriteShareGroupStateResponse.toErrorResponseData(
+        WriteShareGroupStateResponseData expectedData = 
WriteShareGroupStateResponse.toErrorResponseData(
             TOPIC_ID, PARTITION, Errors.FENCED_STATE_EPOCH, 
Errors.FENCED_STATE_EPOCH.message());
-        expectedRecords = List.of();
+        List<CoordinatorRecord> expectedRecords = List.of();
 
         assertEquals(expectedData, result.response());
         assertEquals(expectedRecords, result.records());
+        assertNotEquals(5, shard.getLeaderMapValue(SHARE_PARTITION_KEY));
 
         // No changes to the stateEpochMap.
         assertEquals(1, shard.getStateEpochMapValue(SHARE_PARTITION_KEY));
@@ -907,6 +740,7 @@ class ShareCoordinatorShardTest {
             .build();
 
         initSharePartition(shard, SHARE_PARTITION_KEY);
+        writeAndReplayRecord(shard, 0);
 
         // Set initial state.
         WriteShareGroupStateRequestData request = new 
WriteShareGroupStateRequestData()
@@ -943,17 +777,17 @@ class ShareCoordinatorShardTest {
 
         WriteShareGroupStateResponseData expectedData = 
WriteShareGroupStateResponse.toResponseData(TOPIC_ID, PARTITION);
         List<CoordinatorRecord> expectedRecords = 
List.of(ShareCoordinatorRecordHelpers.newShareSnapshotRecord(
-            GROUP_ID, TOPIC_ID, PARTITION, 
ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 1, 
TIME.milliseconds())
+            GROUP_ID, TOPIC_ID, PARTITION, 
ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 3, 
TIME.milliseconds())
         ));
 
         assertEquals(expectedData, result.response());
         assertEquals(expectedRecords, result.records());
 
         
assertEquals(groupOffset(ShareCoordinatorRecordHelpers.newShareSnapshotRecord(
-            GROUP_ID, TOPIC_ID, PARTITION, 
ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 1, 
TIME.milliseconds())
+            GROUP_ID, TOPIC_ID, PARTITION, 
ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 3, 
TIME.milliseconds())
         ).value().message()), 
shard.getShareStateMapValue(SHARE_PARTITION_KEY));
         assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY));
-        
verify(shard.getMetricsShard()).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME);
+        verify(shard.getMetricsShard(), 
times(3)).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME);
 
         // Acknowledge b1.
         WriteShareGroupStateRequestData requestUpdateB1 = new 
WriteShareGroupStateRequestData()
@@ -1004,7 +838,7 @@ class ShareCoordinatorShardTest {
             .setStartOffset(110)
             .setLeaderEpoch(0)
             .setStateEpoch(0)
-            .setSnapshotEpoch(3)    // since 2nd share snapshot
+            .setSnapshotEpoch(5)    // since subsequent share snapshot
             .setStateBatches(List.of(
                 new PersisterStateBatch(110, 119, (byte) 1, (short) 2),  // b2 
not lost
                 new PersisterStateBatch(120, 129, (byte) 2, (short) 1)
@@ -1023,7 +857,7 @@ class ShareCoordinatorShardTest {
             GROUP_ID, TOPIC_ID, PARTITION, offsetFinal
         ).value().message()), 
shard.getShareStateMapValue(SHARE_PARTITION_KEY));
         assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY));
-        verify(shard.getMetricsShard(), 
times(3)).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME);
+        verify(shard.getMetricsShard(), 
times(5)).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME);
     }
 
     @Test
@@ -1907,6 +1741,9 @@ class ShareCoordinatorShardTest {
     }
 
     private void initSharePartition(ShareCoordinatorShard shard, SharePartitionKey key) {
+        initSharePartition(shard, key, 0);
+    }
+    private void initSharePartition(ShareCoordinatorShard shard, SharePartitionKey key, int stateEpoch) {
         shard.replay(0L, 0L, (short) 0, CoordinatorRecord.record(
             new ShareSnapshotKey()
                 .setGroupId(key.groupId())
@@ -1914,7 +1751,7 @@ class ShareCoordinatorShardTest {
                 .setPartition(key.partition()),
             new ApiMessageAndVersion(
                 new ShareSnapshotValue()
-                    .setStateEpoch(0)
+                    .setStateEpoch(stateEpoch)
                     .setLeaderEpoch(-1)
                     .setStartOffset(-1),
                 (short) 0
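
For context, a hedged usage sketch of the new initSharePartition overload
above, assuming the surrounding test fixtures of ShareCoordinatorShardTest;
the request variable below is hypothetical.

    // Seed the shard at state epoch 1, then issue a write that still carries
    // state epoch 0; the coordinator is expected to fence it.
    initSharePartition(shard, SHARE_PARTITION_KEY, 1);
    CoordinatorResult<WriteShareGroupStateResponseData, CoordinatorRecord> result =
        shard.writeState(requestWithStateEpochZero);
    assertEquals(
        WriteShareGroupStateResponse.toErrorResponseData(
            TOPIC_ID, PARTITION, Errors.FENCED_STATE_EPOCH, Errors.FENCED_STATE_EPOCH.message()),
        result.response());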
