wernerdv commented on code in PR #21383:
URL: https://github.com/apache/kafka/pull/21383#discussion_r2762471317


##########
server/src/test/java/org/apache/kafka/server/FetchSessionTest.java:
##########
@@ -0,0 +1,1752 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.server;
+
+import org.apache.kafka.common.TopicIdPartition;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.compress.Compression;
+import org.apache.kafka.common.message.FetchResponseData;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.record.MemoryRecords;
+import org.apache.kafka.common.record.SimpleRecord;
+import org.apache.kafka.common.requests.FetchMetadata;
+import org.apache.kafka.common.requests.FetchRequest;
+import org.apache.kafka.common.requests.FetchRequest.PartitionData;
+import org.apache.kafka.common.requests.FetchResponse;
+import org.apache.kafka.common.utils.ImplicitLinkedHashCollection;
+import org.apache.kafka.server.FetchContext.FullFetchContext;
+import org.apache.kafka.server.FetchContext.IncrementalFetchContext;
+import org.apache.kafka.server.FetchContext.SessionErrorContext;
+import org.apache.kafka.server.FetchContext.SessionlessFetchContext;
+import org.apache.kafka.server.FetchSession.CachedPartition;
+import org.apache.kafka.server.FetchSession.FetchSessionCache;
+import org.apache.kafka.server.util.MockTime;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.junit.jupiter.params.provider.ValueSource;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
+import static org.apache.kafka.common.protocol.ApiKeys.FETCH;
+import static org.apache.kafka.common.requests.FetchMetadata.FINAL_EPOCH;
+import static org.apache.kafka.common.requests.FetchMetadata.INITIAL;
+import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+@Timeout(120)
+public class FetchSessionTest {
+    private static final List<TopicIdPartition> EMPTY_PART_LIST = List.copyOf(new ArrayList<>());
+
+    @AfterEach
+    public void afterEach() {
+        FetchSessionCache.METRICS_GROUP.removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_SESSIONS);
+        FetchSessionCache.METRICS_GROUP.removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED);
+        FetchSessionCache.METRICS_GROUP.removeMetric(FetchSession.INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC);
+        FetchSessionCache.COUNTER.set(0);
+    }
+
+    @Test
+    public void testNewSessionId() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(3, 100, Integer.MAX_VALUE, 0);
+        for (int i = 0; i < 10_000; i++) {
+            int id = cacheShard.newSessionId();
+            assertTrue(id > 0);
+        }
+    }
+
+    @Test
+    public void testSessionCache() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(3, 100, Integer.MAX_VALUE, 0);
+        assertEquals(0, cacheShard.size());
+
+        int id1 = cacheShard.maybeCreateSession(0, false, 10, true, () -> dummyCreate(10));
+        int id2 = cacheShard.maybeCreateSession(10, false, 20, true, () -> dummyCreate(20));
+        int id3 = cacheShard.maybeCreateSession(20, false, 30, true, () -> dummyCreate(30));
+        assertEquals(INVALID_SESSION_ID, cacheShard.maybeCreateSession(30, false, 40, true, () -> dummyCreate(40)));
+        assertEquals(INVALID_SESSION_ID, cacheShard.maybeCreateSession(40, false, 5, true, () -> dummyCreate(5)));
+        assertCacheContains(cacheShard, id1, id2, id3);
+
+        cacheShard.touch(cacheShard.get(id1).orElseThrow(), 200);
+        int id4 = cacheShard.maybeCreateSession(210, false, 11, true, () -> dummyCreate(11));
+        assertCacheContains(cacheShard, id1, id3, id4);
+
+        cacheShard.touch(cacheShard.get(id1).orElseThrow(), 400);
+        cacheShard.touch(cacheShard.get(id3).orElseThrow(), 390);
+        cacheShard.touch(cacheShard.get(id4).orElseThrow(), 400);
+        int id5 = cacheShard.maybeCreateSession(410, false, 50, true, () -> dummyCreate(50));
+        assertCacheContains(cacheShard, id3, id4, id5);
+        assertEquals(INVALID_SESSION_ID, cacheShard.maybeCreateSession(410, false, 5, true, () -> dummyCreate(5)));
+
+        int id6 = cacheShard.maybeCreateSession(410, true, 5, true, () -> dummyCreate(5));
+        assertCacheContains(cacheShard, id3, id5, id6);
+    }
+
+    @Test
+    public void testResizeCachedSessions() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(2, 100, Integer.MAX_VALUE, 0);
+        assertEquals(0, cacheShard.totalPartitions());
+        assertEquals(0, cacheShard.size());
+        assertEquals(0, cacheShard.evictionsMeter().count());
+
+        int id1 = cacheShard.maybeCreateSession(0, false, 2, true, () -> dummyCreate(2));
+        assertTrue(id1 > 0);
+        assertCacheContains(cacheShard, id1);
+
+        FetchSession session1 = cacheShard.get(id1).orElseThrow();
+        assertEquals(2, session1.size());
+        assertEquals(2, cacheShard.totalPartitions());
+        assertEquals(1, cacheShard.size());
+        assertEquals(0, cacheShard.evictionsMeter().count());
+
+        int id2 = cacheShard.maybeCreateSession(0, false, 4, true, () -> dummyCreate(4));
+        FetchSession session2 = cacheShard.get(id2).orElseThrow();
+        assertTrue(id2 > 0);
+        assertCacheContains(cacheShard, id1, id2);
+        assertEquals(6, cacheShard.totalPartitions());
+        assertEquals(2, cacheShard.size());
+        assertEquals(0, cacheShard.evictionsMeter().count());
+
+        cacheShard.touch(session1, 200);
+        cacheShard.touch(session2, 200);
+        int id3 = cacheShard.maybeCreateSession(200, false, 5, true, () -> dummyCreate(5));
+        assertTrue(id3 > 0);
+        assertCacheContains(cacheShard, id2, id3);
+        assertEquals(9, cacheShard.totalPartitions());
+        assertEquals(2, cacheShard.size());
+        assertEquals(1, cacheShard.evictionsMeter().count());
+
+        cacheShard.remove(id3);
+        assertCacheContains(cacheShard, id2);
+        assertEquals(1, cacheShard.size());
+        assertEquals(1, cacheShard.evictionsMeter().count());
+        assertEquals(4, cacheShard.totalPartitions());
+
+        Iterator<CachedPartition> iter = session2.partitionMap().iterator();
+        iter.next();
+        iter.remove();
+        assertEquals(3, session2.size());
+        assertEquals(4, session2.cachedSize());
+
+        cacheShard.touch(session2, session2.lastUsedMs());
+        assertEquals(3, cacheShard.totalPartitions());
+    }
+
+    @Test
+    public void testCachedLeaderEpoch() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+
+        Map<String, Uuid> topicIds = Map.of("foo", Uuid.randomUuid(), "bar", Uuid.randomUuid());
+        TopicIdPartition tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0));
+        TopicIdPartition tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1));
+        TopicIdPartition tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1));
+        Map<Uuid, String> topicNames = topicIds.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> requestData1 = new LinkedHashMap<>();
+        requestData1.put(tp0.topicPartition(), new PartitionData(tp0.topicId(), 0, 0, 100, Optional.empty()));
+        requestData1.put(tp1.topicPartition(), new PartitionData(tp1.topicId(), 10, 0, 100, Optional.of(1)));
+        requestData1.put(tp2.topicPartition(), new PartitionData(tp2.topicId(), 10, 0, 100, Optional.of(2)));
+
+        FetchRequest request1 = createRequest(INITIAL, requestData1, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context1 = newContext(fetchManager, request1, topicNames);
+        Map<TopicIdPartition, Optional<Integer>> epochs1 = cachedLeaderEpochs(context1);
+        assertEquals(Optional.empty(), epochs1.get(tp0));
+        assertEquals(Optional.of(1), epochs1.get(tp1));
+        assertEquals(Optional.of(2), epochs1.get(tp2));
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> response = new LinkedHashMap<>();
+        response.put(tp0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp0.partition())
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        response.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp1.partition())
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        response.put(tp2, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp2.partition())
+            .setHighWatermark(5)
+            .setLastStableOffset(5)
+            .setLogStartOffset(5));
+
+        int sessionId = context1.updateAndGenerateResponseData(response, List.of()).sessionId();
+
+        // With no changes, the cached epochs should remain the same
+        FetchRequest request2 = createRequest(new FetchMetadata(sessionId, 1), new LinkedHashMap<>(),
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context2 = newContext(fetchManager, request2, topicNames);
+        Map<TopicIdPartition, Optional<Integer>> epochs2 = cachedLeaderEpochs(context2);
+        assertEquals(Optional.empty(), epochs2.get(tp0));
+        assertEquals(Optional.of(1), epochs2.get(tp1));
+        assertEquals(Optional.of(2), epochs2.get(tp2));
+        context2.updateAndGenerateResponseData(response, List.of()).sessionId();
+
+        // Now verify we can change the leader epoch and the context is updated
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> requestData3 = new LinkedHashMap<>();
+        requestData3.put(tp0.topicPartition(), new PartitionData(tp0.topicId(), 0, 0, 100, Optional.of(6)));
+        requestData3.put(tp1.topicPartition(), new PartitionData(tp1.topicId(), 10, 0, 100, Optional.empty()));
+        requestData3.put(tp2.topicPartition(), new PartitionData(tp2.topicId(), 10, 0, 100, Optional.of(3)));
+
+        FetchRequest request3 = createRequest(new FetchMetadata(sessionId, 2), requestData3,
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context3 = newContext(fetchManager, request3, topicNames);
+        Map<TopicIdPartition, Optional<Integer>> epochs3 = cachedLeaderEpochs(context3);
+        assertEquals(Optional.of(6), epochs3.get(tp0));
+        assertEquals(Optional.empty(), epochs3.get(tp1));
+        assertEquals(Optional.of(3), epochs3.get(tp2));
+    }
+
+    @Test
+    public void testLastFetchedEpoch() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+
+        Map<String, Uuid> topicIds = Map.of("foo", Uuid.randomUuid(), "bar", Uuid.randomUuid());
+        Map<Uuid, String> topicNames = topicIds.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+        TopicIdPartition tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0));
+        TopicIdPartition tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1));
+        TopicIdPartition tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1));
+
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> requestData1 = new LinkedHashMap<>();
+        requestData1.put(tp0.topicPartition(), new PartitionData(tp0.topicId(), 0, 0, 100, Optional.empty(), Optional.empty()));
+        requestData1.put(tp1.topicPartition(), new PartitionData(tp1.topicId(), 10, 0, 100, Optional.of(1), Optional.empty()));
+        requestData1.put(tp2.topicPartition(), new PartitionData(tp2.topicId(), 10, 0, 100, Optional.of(2), Optional.of(1)));
+
+        FetchRequest request1 = createRequest(INITIAL, requestData1, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context1 = newContext(fetchManager, request1, topicNames);
+        assertEquals(Map.of(tp0, Optional.empty(), tp1, Optional.of(1), tp2, Optional.of(2)), cachedLeaderEpochs(context1));
+        assertEquals(Map.of(tp0, Optional.empty(), tp1, Optional.empty(), tp2, Optional.of(1)), cachedLastFetchedEpochs(context1));
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> response = new LinkedHashMap<>();
+        response.put(tp0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp0.partition())
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        response.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp1.partition())
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        response.put(tp2, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp2.partition())
+            .setHighWatermark(5)
+            .setLastStableOffset(5)
+            .setLogStartOffset(5));
+
+        int sessionId = context1.updateAndGenerateResponseData(response, List.of()).sessionId();
+
+        // With no changes, the cached epochs should remain the same
+        FetchRequest request2 = createRequest(new FetchMetadata(sessionId, 1), new LinkedHashMap<>(),
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context2 = newContext(fetchManager, request2, topicNames);
+        assertEquals(Map.of(tp0, Optional.empty(), tp1, Optional.of(1), tp2, Optional.of(2)), cachedLeaderEpochs(context2));
+        assertEquals(Map.of(tp0, Optional.empty(), tp1, Optional.empty(), tp2, Optional.of(1)), cachedLastFetchedEpochs(context2));
+        context2.updateAndGenerateResponseData(response, List.of()).sessionId();
+
+        // Now verify we can change the leader epoch and the context is updated
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> requestData3 = new LinkedHashMap<>();
+        requestData3.put(tp0.topicPartition(), new PartitionData(tp0.topicId(), 0, 0, 100, Optional.of(6), Optional.of(5)));
+        requestData3.put(tp1.topicPartition(), new PartitionData(tp1.topicId(), 10, 0, 100, Optional.empty(), Optional.empty()));
+        requestData3.put(tp2.topicPartition(), new PartitionData(tp2.topicId(), 10, 0, 100, Optional.of(3), Optional.of(3)));
+
+        FetchRequest request3 = createRequest(new FetchMetadata(sessionId, 2), requestData3, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context3 = newContext(fetchManager, request3, topicNames);
+        assertEquals(Map.of(tp0, Optional.of(6), tp1, Optional.empty(), tp2, Optional.of(3)), cachedLeaderEpochs(context3));
+        assertEquals(Map.of(tp0, Optional.of(5), tp1, Optional.empty(), tp2, Optional.of(3)), cachedLastFetchedEpochs(context3));
+    }
+
+    @Test
+    public void testFetchRequests() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+        Map<Uuid, String> topicNames = Map.of(Uuid.randomUuid(), "foo", Uuid.randomUuid(), "bar");
+        Map<String, Uuid> topicIds = topicNames.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+        TopicIdPartition tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0));
+        TopicIdPartition tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1));
+        TopicIdPartition tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 0));
+        TopicIdPartition tp3 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1));
+
+        // Verify that SESSIONLESS requests get a SessionlessFetchContext
+        FetchRequest request = createRequest(FetchMetadata.LEGACY, new HashMap<>(), EMPTY_PART_LIST, true, FETCH.latestVersion());
+        FetchContext context = newContext(fetchManager, request, topicNames);
+        assertInstanceOf(SessionlessFetchContext.class, context);
+
+        // Create a new fetch session with a FULL fetch request
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqData2 = new LinkedHashMap<>();
+        reqData2.put(tp0.topicPartition(), new PartitionData(tp0.topicId(), 0, 0, 100, Optional.empty()));
+        reqData2.put(tp1.topicPartition(), new PartitionData(tp1.topicId(), 10, 0, 100, Optional.empty()));
+        FetchRequest request2 = createRequest(INITIAL, reqData2, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context2 = newContext(fetchManager, request2, topicNames);
+        assertInstanceOf(FullFetchContext.class, context2);
+
+        Iterator<Map.Entry<TopicPartition, FetchRequest.PartitionData>> reqData2Iter = reqData2.entrySet().iterator();
+        context2.foreachPartition((topicIdPart, data) -> {
+            Map.Entry<TopicPartition, FetchRequest.PartitionData> entry = reqData2Iter.next();
+            assertEquals(entry.getKey(), topicIdPart.topicPartition());
+            assertEquals(topicIds.get(entry.getKey().topic()), topicIdPart.topicId());
+            assertEquals(entry.getValue(), data);
+        });
+        assertEquals(0, context2.getFetchOffset(tp0).orElseThrow());
+        assertEquals(10, context2.getFetchOffset(tp1).orElseThrow());
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData2 = new LinkedHashMap<>();
+        respData2.put(tp0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        respData2.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse resp2 = context2.updateAndGenerateResponseData(respData2, List.of());
+        assertEquals(Errors.NONE, resp2.error());
+        assertTrue(resp2.sessionId() != INVALID_SESSION_ID);
+        assertEquals(
+            respData2.entrySet()
+                .stream()
+                .collect(Collectors.toMap(entry -> entry.getKey().topicPartition(), Map.Entry::getValue)),
+            resp2.responseData(topicNames, request2.version())
+        );
+
+        // Test trying to create a new session with an invalid epoch
+        FetchRequest request3 = createRequest(new FetchMetadata(resp2.sessionId(), 5), reqData2,
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context3 = newContext(fetchManager, request3, topicNames);
+        assertInstanceOf(SessionErrorContext.class, context3);
+        assertEquals(Errors.INVALID_FETCH_SESSION_EPOCH, context3.updateAndGenerateResponseData(respData2, List.of()).error());
+
+        // Test trying to create a new session with a non-existent session id
+        FetchRequest request4 = createRequest(new FetchMetadata(resp2.sessionId() + 1, 1), reqData2,
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context4 = newContext(fetchManager, request4, topicNames);
+        assertEquals(Errors.FETCH_SESSION_ID_NOT_FOUND, context4.updateAndGenerateResponseData(respData2, List.of()).error());
+
+        // Continue the first fetch session we created.
+        FetchRequest request5 = createRequest(new FetchMetadata(resp2.sessionId(), 1), new LinkedHashMap<>(),
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context5 = newContext(fetchManager, request5, topicNames);
+        assertInstanceOf(IncrementalFetchContext.class, context5);
+
+        Iterator<Map.Entry<TopicPartition, FetchRequest.PartitionData>> reqData5Iter = reqData2.entrySet().iterator();
+        context5.foreachPartition((topicIdPart, data) -> {
+            Map.Entry<TopicPartition, FetchRequest.PartitionData> entry = reqData5Iter.next();
+            assertEquals(entry.getKey(), topicIdPart.topicPartition());
+            assertEquals(topicIds.get(entry.getKey().topic()), topicIdPart.topicId());
+            assertEquals(entry.getValue(), data);
+        });
+        assertEquals(10, context5.getFetchOffset(tp1).orElseThrow());
+
+        FetchResponse resp5 = context5.updateAndGenerateResponseData(respData2, List.of());
+        assertEquals(Errors.NONE, resp5.error());
+        assertEquals(resp2.sessionId(), resp5.sessionId());
+        assertEquals(0, resp5.responseData(topicNames, request5.version()).size());
+
+        // Test setting an invalid fetch session epoch.
+        FetchRequest request6 = createRequest(new FetchMetadata(resp2.sessionId(), 5), reqData2,
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context6 = newContext(fetchManager, request6, topicNames);
+        assertInstanceOf(SessionErrorContext.class, context6);
+        assertEquals(Errors.INVALID_FETCH_SESSION_EPOCH, context6.updateAndGenerateResponseData(respData2, List.of()).error());
+
+        // Test generating a throttled response for the incremental fetch session
+        FetchRequest request7 = createRequest(new FetchMetadata(resp2.sessionId(), 2), new LinkedHashMap<>(),
+            EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context7 = newContext(fetchManager, request7, topicNames);
+        FetchResponse resp7 = context7.getThrottledResponse(100, List.of());
+        assertEquals(Errors.NONE, resp7.error());
+        assertEquals(resp2.sessionId(), resp7.sessionId());
+        assertEquals(100, resp7.throttleTimeMs());
+
+        // Close the incremental fetch session.
+        int prevSessionId = resp5.sessionId();
+        int nextSessionId;
+        do {
+            LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqData8 = new LinkedHashMap<>();
+            reqData8.put(tp2.topicPartition(), new PartitionData(tp2.topicId(), 0, 0, 100, Optional.empty()));
+            reqData8.put(tp3.topicPartition(), new PartitionData(tp3.topicId(), 10, 0, 100, Optional.empty()));
+            FetchRequest request8 = createRequest(new FetchMetadata(prevSessionId, FINAL_EPOCH), reqData8,
+                EMPTY_PART_LIST, false, FETCH.latestVersion());
+            FetchContext context8 = newContext(fetchManager, request8, topicNames);
+            assertInstanceOf(SessionlessFetchContext.class, context8);
+            assertEquals(0, cacheShard.size());
+
+            LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData8 = new LinkedHashMap<>();
+            respData8.put(tp2, new FetchResponseData.PartitionData()
+                .setPartitionIndex(0)
+                .setHighWatermark(100)
+                .setLastStableOffset(100)
+                .setLogStartOffset(100));
+            respData8.put(tp3, new FetchResponseData.PartitionData()
+                .setPartitionIndex(1)
+                .setHighWatermark(100)
+                .setLastStableOffset(100)
+                .setLogStartOffset(100));
+            FetchResponse resp8 = context8.updateAndGenerateResponseData(respData8, List.of());
+            assertEquals(Errors.NONE, resp8.error());
+
+            nextSessionId = resp8.sessionId();
+        } while (nextSessionId == prevSessionId);
+    }
+
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    public void testIncrementalFetchSession(boolean usesTopicIds) {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+        Map<Uuid, String> topicNames = usesTopicIds
+            ? Map.of(Uuid.randomUuid(), "foo", Uuid.randomUuid(), "bar")
+            : Map.of();
+        Map<String, Uuid> topicIds = topicNames.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+        short version = usesTopicIds ? FETCH.latestVersion() : (short) 12;
+        Uuid fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID);
+        Uuid barId = topicIds.getOrDefault("bar", Uuid.ZERO_UUID);
+        TopicIdPartition tp0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0));
+        TopicIdPartition tp1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1));
+        TopicIdPartition tp2 = new TopicIdPartition(barId, new TopicPartition("bar", 0));
+
+        // Create a new fetch session with foo-0 and foo-1
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqData1 = new LinkedHashMap<>();
+        reqData1.put(tp0.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        reqData1.put(tp1.topicPartition(), new PartitionData(fooId, 10, 0, 100, Optional.empty()));
+        FetchRequest request1 = createRequest(INITIAL, reqData1, EMPTY_PART_LIST, false, version);
+        FetchContext context1 = newContext(fetchManager, request1, topicNames);
+        assertInstanceOf(FullFetchContext.class, context1);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData1 = new LinkedHashMap<>();
+        respData1.put(tp0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        respData1.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse resp1 = context1.updateAndGenerateResponseData(respData1, List.of());
+        assertEquals(Errors.NONE, resp1.error());
+        assertTrue(resp1.sessionId() != INVALID_SESSION_ID);
+        assertEquals(2, resp1.responseData(topicNames, request1.version()).size());
+
+        // Create an incremental fetch request that removes foo-0 and adds bar-0
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqData2 = new LinkedHashMap<>();
+        reqData2.put(tp2.topicPartition(), new PartitionData(barId, 15, 0, 0, Optional.empty()));
+        FetchRequest request2 = createRequest(new FetchMetadata(resp1.sessionId(), 1), reqData2, List.of(tp0), false, version);
+        FetchContext context2 = newContext(fetchManager, request2, topicNames);
+        assertInstanceOf(IncrementalFetchContext.class, context2);
+
+        Set<TopicIdPartition> parts = new LinkedHashSet<>();
+        parts.add(tp1);
+        parts.add(tp2);
+        Iterator<TopicIdPartition> reqData2Iter = parts.iterator();
+        context2.foreachPartition((topicIdPart, data) -> assertEquals(reqData2Iter.next(), topicIdPart));
+        assertEquals(Optional.empty(), context2.getFetchOffset(tp0));
+        assertEquals(10, context2.getFetchOffset(tp1).orElseThrow());
+        assertEquals(15, context2.getFetchOffset(tp2).orElseThrow());
+        assertEquals(Optional.empty(), context2.getFetchOffset(new TopicIdPartition(barId, new TopicPartition("bar", 2))));
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData2 = new LinkedHashMap<>();
+        respData2.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        respData2.put(tp2, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse resp2 = context2.updateAndGenerateResponseData(respData2, List.of());
+        assertEquals(Errors.NONE, resp2.error());
+        assertEquals(1, resp2.responseData(topicNames, request2.version()).size());
+        assertTrue(resp2.sessionId() > 0);
+    }
+
+    // This test simulates a request without IDs sent to a broker with IDs.
+    @Test
+    public void testFetchSessionWithUnknownIdOldRequestVersion() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+        Map<Uuid, String> topicNames = Map.of(Uuid.randomUuid(), "foo", Uuid.randomUuid(), "bar");
+        Map<String, Uuid> topicIds = topicNames.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+        TopicIdPartition tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0));
+        TopicIdPartition tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1));
+
+        // Create a new fetch session with foo-0 and foo-1
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqData1 = new LinkedHashMap<>();
+        reqData1.put(tp0.topicPartition(), new PartitionData(tp0.topicId(), 0, 0, 100, Optional.empty()));
+        reqData1.put(tp1.topicPartition(), new PartitionData(Uuid.ZERO_UUID, 10, 0, 100, Optional.empty()));
+        FetchRequest request1 = createRequestWithoutTopicIds(INITIAL, reqData1);
+        // Simulate unknown topic ID for foo.
+        Map<Uuid, String> topicNamesOnlyBar = Map.of(topicIds.get("bar"), "bar");
+        // We should not throw an error since we have an older request version.
+        FetchContext context1 = newContext(fetchManager, request1, topicNamesOnlyBar);
+        assertInstanceOf(FullFetchContext.class, context1);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData1 = new LinkedHashMap<>();
+        respData1.put(tp0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        respData1.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse resp1 = context1.updateAndGenerateResponseData(respData1, List.of());
+        // Since we are ignoring IDs, we should have no errors.
+        assertEquals(Errors.NONE, resp1.error());
+        assertTrue(resp1.sessionId() != INVALID_SESSION_ID);
+        assertEquals(2, resp1.responseData(topicNames, request1.version()).size());
+        resp1.responseData(topicNames, request1.version()).forEach((tp, resp) ->
+            assertEquals(Errors.NONE.code(), resp.errorCode()));
+    }
+
+    @Test
+    public void testFetchSessionWithUnknownId() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+        Uuid fooId = Uuid.randomUuid();
+        Uuid barId = Uuid.randomUuid();
+        Uuid zarId = Uuid.randomUuid();
+        Map<Uuid, String> topicNames = Map.of(fooId, "foo", barId, "bar", zarId, "zar");
+        TopicIdPartition foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0));
+        TopicIdPartition foo1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1));
+        TopicIdPartition zar0 = new TopicIdPartition(zarId, new TopicPartition("zar", 0));
+        TopicIdPartition emptyFoo0 = new TopicIdPartition(fooId, new TopicPartition(null, 0));
+        TopicIdPartition emptyFoo1 = new TopicIdPartition(fooId, new TopicPartition(null, 1));
+        TopicIdPartition emptyZar0 = new TopicIdPartition(zarId, new TopicPartition(null, 0));
+
+        // Create a new fetch session with foo-0 and foo-1
+        LinkedHashMap<TopicPartition, PartitionData> reqData1 = new LinkedHashMap<>();
+        reqData1.put(foo0.topicPartition(), new PartitionData(foo0.topicId(), 0, 0, 100, Optional.empty()));
+        reqData1.put(foo1.topicPartition(), new PartitionData(foo1.topicId(), 10, 0, 100, Optional.empty()));
+        reqData1.put(zar0.topicPartition(), new PartitionData(zar0.topicId(), 10, 0, 100, Optional.empty()));
+        FetchRequest request1 = createRequest(INITIAL, reqData1, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        // Simulate unknown topic IDs for foo and zar.
+        Map<Uuid, String> topicNamesOnlyBar = Map.of(barId, "bar");
+        // We should not throw an error; the partitions with unknown topic IDs are left unresolved.
+        FetchContext context1 = newContext(fetchManager, request1, topicNamesOnlyBar);
+        assertInstanceOf(FullFetchContext.class, context1);
+        assertPartitionsOrder(context1, List.of(emptyFoo0, emptyFoo1, emptyZar0));
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData1 = new LinkedHashMap<>();
+        respData1.put(emptyFoo0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setErrorCode(Errors.UNKNOWN_TOPIC_ID.code()));
+        respData1.put(emptyFoo1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setErrorCode(Errors.UNKNOWN_TOPIC_ID.code()));
+        respData1.put(emptyZar0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setErrorCode(Errors.UNKNOWN_TOPIC_ID.code()));
+        FetchResponse resp1 = context1.updateAndGenerateResponseData(respData1, List.of());
+        // On the latest request version, we should have unknown topic ID errors.
+        assertEquals(Errors.NONE, resp1.error());
+        assertTrue(resp1.sessionId() != INVALID_SESSION_ID);
+        assertEquals(
+            Map.of(
+                foo0.topicPartition(), Errors.UNKNOWN_TOPIC_ID.code(),
+                foo1.topicPartition(), Errors.UNKNOWN_TOPIC_ID.code(),
+                zar0.topicPartition(), Errors.UNKNOWN_TOPIC_ID.code()
+            ),
+            resp1.responseData(topicNames, request1.version()).entrySet()
+                .stream()
+                .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().errorCode()))
+        );
+
+        // Create an incremental request where we resolve the partitions
+        FetchRequest request2 = createRequest(new FetchMetadata(resp1.sessionId(), 1), new LinkedHashMap<>(), EMPTY_PART_LIST, false, FETCH.latestVersion());
+        Map<Uuid, String> topicNamesNoZar = Map.of(fooId, "foo", barId, "bar");
+        FetchContext context2 = newContext(fetchManager, request2, topicNamesNoZar);
+        assertInstanceOf(IncrementalFetchContext.class, context2);
+        // Topic names in the session but not in the request are lazily resolved via foreachPartition. Resolve foo topic IDs here.
+        assertPartitionsOrder(context2, List.of(foo0, foo1, emptyZar0));
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData2 = new LinkedHashMap<>();
+        respData2.put(foo0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        respData2.put(foo1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        respData2.put(emptyZar0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setErrorCode(Errors.UNKNOWN_TOPIC_ID.code()));
+        FetchResponse resp2 = context2.updateAndGenerateResponseData(respData2, List.of());
+        // The foo partitions are now resolved, while zar still reports an unknown topic ID error.
+        assertEquals(Errors.NONE, resp2.error());
+        assertTrue(resp2.sessionId() != INVALID_SESSION_ID);
+        assertEquals(3, resp2.responseData(topicNames, request2.version()).size());
+        assertEquals(
+            Map.of(
+                foo0.topicPartition(), Errors.NONE.code(),
+                foo1.topicPartition(), Errors.NONE.code(),
+                zar0.topicPartition(), Errors.UNKNOWN_TOPIC_ID.code()
+            ),
+            resp2.responseData(topicNames, request2.version()).entrySet()
+                .stream()
+                .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().errorCode()))
+        );
+    }
+
+    @Test
+    public void testIncrementalFetchSessionWithIdsWhenSessionDoesNotUseIds() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+        Map<Uuid, String> topicNames = new HashMap<>();
+        TopicIdPartition foo0 = new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 0));
+
+        // Create a new fetch session with foo-0
+        LinkedHashMap<TopicPartition, PartitionData> reqData1 = new LinkedHashMap<>();
+        reqData1.put(foo0.topicPartition(), new PartitionData(Uuid.ZERO_UUID, 0, 0, 100, Optional.empty()));
+        FetchRequest request1 = createRequestWithoutTopicIds(INITIAL, reqData1);
+        // Start a fetch session using a request version that does not use topic IDs.
+        FetchContext context1 = newContext(fetchManager, request1, topicNames);
+        assertInstanceOf(FullFetchContext.class, context1);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData1 = new LinkedHashMap<>();
+        respData1.put(foo0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        FetchResponse resp1 = context1.updateAndGenerateResponseData(respData1, List.of());
+        assertEquals(Errors.NONE, resp1.error());
+        assertTrue(resp1.sessionId() != INVALID_SESSION_ID);
+
+        // Create an incremental fetch request as though no topics changed. However, send a v13 request.
+        // Also simulate the topic ID being found on the server.
+        topicNames.put(Uuid.randomUuid(), "foo");
+        FetchRequest request2 = createRequest(new FetchMetadata(resp1.sessionId(), 1), new LinkedHashMap<>(), EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context2 = newContext(fetchManager, request2, topicNames);
+
+        assertInstanceOf(SessionErrorContext.class, context2);
+        assertEquals(Errors.FETCH_SESSION_TOPIC_ID_ERROR, context2.updateAndGenerateResponseData(new LinkedHashMap<>(), List.of()).error());
+    }
+
+    @Test
+    public void testIncrementalFetchSessionWithoutIdsWhenSessionUsesIds() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+        Uuid fooId = Uuid.randomUuid();
+        Map<Uuid, String> topicNames = new HashMap<>();
+        topicNames.put(fooId, "foo");
+        TopicIdPartition foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0));
+
+        // Create a new fetch session with foo-0
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqData1 = new LinkedHashMap<>();
+        reqData1.put(foo0.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        FetchRequest request1 = createRequest(INITIAL, reqData1, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        // Start a fetch session using a request version that uses topic IDs.
+        FetchContext context1 = newContext(fetchManager, request1, topicNames);
+        assertInstanceOf(FullFetchContext.class, context1);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData1 = new LinkedHashMap<>();
+        respData1.put(foo0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        FetchResponse resp1 = context1.updateAndGenerateResponseData(respData1, List.of());
+        assertEquals(Errors.NONE, resp1.error());
+        assertTrue(resp1.sessionId() != INVALID_SESSION_ID);
+
+        // Create an incremental fetch request as though no topics changed. However, send a v12 request.
+        // Also simulate the topic ID not being found on the server.
+        topicNames.remove(fooId);
+
+        FetchRequest request2 = createRequestWithoutTopicIds(new FetchMetadata(resp1.sessionId(), 1), new LinkedHashMap<>());
+        FetchContext context2 = newContext(fetchManager, request2, topicNames);
+
+        assertInstanceOf(SessionErrorContext.class, context2);
+        assertEquals(Errors.FETCH_SESSION_TOPIC_ID_ERROR, context2.updateAndGenerateResponseData(new LinkedHashMap<>(), List.of()).error());
+    }
+
+    // This test simulates a session where the topic ID changes broker side (the one handling the request) in both the metadata cache and the log
+    // -- as though the topic is deleted and recreated.
+    @Test
+    public void testFetchSessionUpdateTopicIdsBrokerSide() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+        Map<Uuid, String> topicNames = Map.of(Uuid.randomUuid(), "foo", Uuid.randomUuid(), "bar");
+        Map<String, Uuid> topicIds = topicNames.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+        TopicIdPartition tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0));
+        TopicIdPartition tp1 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1));
+
+        // Create a new fetch session with foo-0 and bar-1
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqData1 = new LinkedHashMap<>();
+        reqData1.put(tp0.topicPartition(), new PartitionData(tp0.topicId(), 0, 0, 100, Optional.empty()));
+        reqData1.put(tp1.topicPartition(), new PartitionData(tp1.topicId(), 10, 0, 100, Optional.empty()));
+        FetchRequest request1 = createRequest(INITIAL, reqData1, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        // Start a fetch session. Simulate unknown partition foo-0.
+        FetchContext context1 = newContext(fetchManager, request1, topicNames);
+        assertInstanceOf(FullFetchContext.class, context1);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData1 = new LinkedHashMap<>();
+        respData1.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        respData1.put(tp0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(-1)
+            .setLastStableOffset(-1)
+            .setLogStartOffset(-1)
+            .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()));
+        FetchResponse resp1 = context1.updateAndGenerateResponseData(respData1, List.of());
+        assertEquals(Errors.NONE, resp1.error());
+        assertTrue(resp1.sessionId() != INVALID_SESSION_ID);
+        assertEquals(2, resp1.responseData(topicNames, request1.version()).size());
+
+        // Create an incremental fetch request as though no topics changed.
+        FetchRequest request2 = createRequest(new FetchMetadata(resp1.sessionId(), 1), new LinkedHashMap<>(), EMPTY_PART_LIST, false, FETCH.latestVersion());
+        // Simulate ID changing on server.
+        Map<Uuid, String> topicNamesFooChanged = Map.of(topicIds.get("bar"), "bar", Uuid.randomUuid(), "foo");
+        FetchContext context2 = newContext(fetchManager, request2, topicNamesFooChanged);
+        assertInstanceOf(IncrementalFetchContext.class, context2);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData2 = new LinkedHashMap<>();
+        // If the topic ID is different on the broker, it will likely be different in the log as well. Simulate the log check finding an inconsistent ID.
+        respData2.put(tp0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(-1)
+            .setLastStableOffset(-1)
+            .setLogStartOffset(-1)
+            .setErrorCode(Errors.INCONSISTENT_TOPIC_ID.code()));
+        FetchResponse resp2 = context2.updateAndGenerateResponseData(respData2, List.of());
+        assertEquals(Errors.NONE, resp2.error());
+        assertTrue(resp2.sessionId() > 0);
+
+        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> responseData2 = resp2.responseData(topicNames, request2.version());
+        // We should have the inconsistent topic ID error on the partition
+        assertEquals(Errors.INCONSISTENT_TOPIC_ID.code(), responseData2.get(tp0.topicPartition()).errorCode());
+    }
+
+    @Test
+    public void testResolveUnknownPartitions() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+
+        TopicIdPartition foo = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0));
+        TopicIdPartition bar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("bar", 0));
+        TopicIdPartition zar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("zar", 0));
+
+        TopicIdPartition fooUnresolved = new TopicIdPartition(foo.topicId(), new TopicPartition(null, foo.partition()));
+        TopicIdPartition barUnresolved = new TopicIdPartition(bar.topicId(), new TopicPartition(null, bar.partition()));
+        TopicIdPartition zarUnresolved = new TopicIdPartition(zar.topicId(), new TopicPartition(null, zar.partition()));
+
+        // The metadata cache does not know about the topic.
+        FetchContext context1 = newContext(INITIAL, List.of(foo, bar, zar), fetchManager, Map.of());
+
+        // So the context contains unresolved partitions.
+        assertInstanceOf(FullFetchContext.class, context1);
+        assertPartitionsOrder(context1, List.of(fooUnresolved, barUnresolved, zarUnresolved));
+
+        // The response is sent back to create the session.
+        int sessionId = updateAndGenerateResponseDataSessionId(context1);
+
+        // The metadata cache only knows about foo.
+        FetchContext context2 = newContext(new FetchMetadata(sessionId, 1), List.of(), fetchManager, Map.of(foo.topicId(), foo.topic()));
+
+        // So foo is resolved but not the others.
+        assertInstanceOf(IncrementalFetchContext.class, context2);
+        assertPartitionsOrder(context2, List.of(foo, barUnresolved, zarUnresolved));
+
+        updateAndGenerateResponseDataSessionId(context2);
+
+        // The metadata cache knows about foo and bar.
+        FetchContext context3 = newContext(
+            new FetchMetadata(sessionId, 2),
+            List.of(bar),
+            fetchManager,
+            Map.of(foo.topicId(), foo.topic(), bar.topicId(), bar.topic())
+        );
+
+        // So foo and bar are resolved.
+        assertInstanceOf(IncrementalFetchContext.class, context3);
+        assertPartitionsOrder(context3, List.of(foo, bar, zarUnresolved));
+
+        updateAndGenerateResponseDataSessionId(context3);
+
+        // The metadata cache knows about all topics.
+        FetchContext context4 = newContext(
+            new FetchMetadata(sessionId, 3),
+            List.of(),
+            fetchManager,
+            Map.of(foo.topicId(), foo.topic(), bar.topicId(), bar.topic(), zar.topicId(), zar.topic())
+        );
+
+        // So all topics are resolved.
+        assertInstanceOf(IncrementalFetchContext.class, context4);
+        assertPartitionsOrder(context4, List.of(foo, bar, zar));
+
+        updateAndGenerateResponseDataSessionId(context4);
+
+        // The metadata cache does not know about the topics anymore (e.g. deleted).
+        FetchContext context5 = newContext(new FetchMetadata(sessionId, 4), List.of(), fetchManager, Map.of());
+
+        // All topics remain resolved.
+        assertInstanceOf(IncrementalFetchContext.class, context5);
+        assertPartitionsOrder(context5, List.of(foo, bar, zar));
+    }
+
+    // This test simulates trying to forget a topic partition with all possible topic ID usages for both requests.
+    @ParameterizedTest
+    @MethodSource("idUsageCombinations")
+    public void testToForgetPartitions(boolean fooStartsResolved, boolean fooEndsResolved) {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+
+        TopicIdPartition foo = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0));
+        TopicIdPartition bar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("bar", 0));
+
+        TopicIdPartition fooUnresolved = new TopicIdPartition(foo.topicId(), new TopicPartition(null, foo.partition()));
+        TopicIdPartition barUnresolved = new TopicIdPartition(bar.topicId(), new TopicPartition(null, bar.partition()));
+
+        // Create a new context where foo's resolution depends on fooStartsResolved and bar is unresolved.
+        Map<Uuid, String> context1Names = fooStartsResolved ? Map.of(foo.topicId(), foo.topic()) : Map.of();
+        TopicIdPartition fooContext1 = fooStartsResolved ? foo : fooUnresolved;
+        FetchContext context1 = newContext(
+            INITIAL,
+            List.of(fooContext1, bar),
+            List.of(),
+            fetchManager,
+            context1Names
+        );
+
+        // So the context contains unresolved bar and a resolved foo iff fooStartsResolved.
+        assertInstanceOf(FullFetchContext.class, context1);
+        assertPartitionsOrder(context1, List.of(fooContext1, barUnresolved));
+
+        // The response is sent back to create the session.
+        int sessionId = updateAndGenerateResponseDataSessionId(context1);
+
+        // Forget foo, but keep bar. Foo's resolution depends on fooEndsResolved and bar stays unresolved.
+        Map<Uuid, String> context2Names = fooEndsResolved ? Map.of(foo.topicId(), foo.topic()) : Map.of();
+        TopicIdPartition fooContext2 = fooEndsResolved ? foo : fooUnresolved;
+        FetchContext context2 = newContext(
+            new FetchMetadata(sessionId, 1),
+            List.of(),
+            List.of(fooContext2),
+            fetchManager,
+            context2Names
+        );
+
+        // So foo is removed but not the others.
+        assertInstanceOf(IncrementalFetchContext.class, context2);
+        assertPartitionsOrder(context2, List.of(barUnresolved));
+
+        updateAndGenerateResponseDataSessionId(context2);
+
+        // Now remove bar
+        FetchContext context3 = newContext(
+            new FetchMetadata(sessionId, 2),
+            List.of(),
+            List.of(bar),
+            fetchManager,
+            Map.of()
+        );
+
+        // Context is sessionless since it is empty.
+        assertInstanceOf(SessionlessFetchContext.class, context3);
+        assertPartitionsOrder(context3, List.of());
+    }
+
+    @Test
+    public void testUpdateAndGenerateResponseData() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+
+        // Give both topics errors so they will stay in the session.
+        TopicIdPartition foo = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0));
+        TopicIdPartition bar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("bar", 0));
+
+        // Foo will always be resolved, and bar will never be resolved, on the receiving broker.
+        Map<Uuid, String> receivingBrokerTopicNames = Map.of(foo.topicId(), foo.topic());
+        // The sender will know both topics' ID-to-name mappings.
+        Map<Uuid, String> sendingTopicNames = Map.of(foo.topicId(), foo.topic(), bar.topicId(), bar.topic());
+
+        // Start with a sessionless context.
+        FetchContext context1 = newContext(
+            FetchMetadata.LEGACY,
+            List.of(foo, bar),
+            fetchManager,
+            receivingBrokerTopicNames
+        );
+        assertInstanceOf(SessionlessFetchContext.class, context1);
+        // Check the response can be read as expected.
+        checkResponseData(
+            Map.of(
+                foo.topicPartition(), Errors.UNKNOWN_TOPIC_OR_PARTITION.code(),
+                bar.topicPartition(), Errors.UNKNOWN_TOPIC_ID.code()
+            ),
+            updateAndGenerateResponseData(context1),
+            sendingTopicNames
+        );
+
+        // Now create a full context.
+        FetchContext context2 = newContext(
+            INITIAL,
+            List.of(foo, bar),
+            fetchManager,
+            receivingBrokerTopicNames
+        );
+        assertInstanceOf(FullFetchContext.class, context2);
+
+        // We want to get the session ID to build more contexts in this session.
+        FetchResponse response2 = updateAndGenerateResponseData(context2);
+        int sessionId = response2.sessionId();
+        checkResponseData(
+            Map.of(
+                foo.topicPartition(), Errors.UNKNOWN_TOPIC_OR_PARTITION.code(),
+                bar.topicPartition(), Errors.UNKNOWN_TOPIC_ID.code()
+            ),
+            response2,
+            sendingTopicNames
+        );
+
+        // Now create an incremental context. We re-add foo as though the partition data is updated. In a real broker, the data would update.
+        FetchContext context3 = newContext(
+            new FetchMetadata(sessionId, 1),
+            List.of(),
+            fetchManager,
+            receivingBrokerTopicNames
+        );
+        assertInstanceOf(IncrementalFetchContext.class, context3);
+        checkResponseData(
+            Map.of(
+                foo.topicPartition(), Errors.UNKNOWN_TOPIC_OR_PARTITION.code(),
+                bar.topicPartition(), Errors.UNKNOWN_TOPIC_ID.code()
+            ),
+            updateAndGenerateResponseData(context3),
+            sendingTopicNames
+        );
+
+        // Finally create an error context by using the same epoch
+        FetchContext context4 = newContext(
+            new FetchMetadata(sessionId, 1),
+            List.of(),
+            fetchManager,
+            receivingBrokerTopicNames
+        );
+        assertInstanceOf(SessionErrorContext.class, context4);
+        // The response should be empty.
+        assertEquals(List.of(), updateAndGenerateResponseData(context4).data().responses());
+    }
+
+    @Test
+    public void testFetchSessionExpiration() {
+        MockTime time = new MockTime();
+        // set maximum entries to 2 to allow for eviction later
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(2, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(time, cacheShard);
+        Uuid fooId = Uuid.randomUuid();
+        Map<Uuid, String> topicNames = Map.of(fooId, "foo");
+        TopicIdPartition foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0));
+        TopicIdPartition foo1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1));
+
+        // Create a new fetch session, session 1
+        LinkedHashMap<TopicPartition, PartitionData> session1req = new LinkedHashMap<>();
+        session1req.put(foo0.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        session1req.put(foo1.topicPartition(), new PartitionData(fooId, 10, 0, 100, Optional.empty()));
+        FetchRequest session1request1 = createRequest(INITIAL, session1req, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext session1context1 = newContext(fetchManager, session1request1, topicNames);
+        assertInstanceOf(FullFetchContext.class, session1context1);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData1 = new LinkedHashMap<>();
+        respData1.put(foo0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        respData1.put(foo1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse session1resp = session1context1.updateAndGenerateResponseData(respData1, List.of());
+        assertEquals(Errors.NONE, session1resp.error());
+        assertTrue(session1resp.sessionId() != INVALID_SESSION_ID);
+        assertEquals(2, session1resp.responseData(topicNames, session1request1.version()).size());
+
+        // check the session was entered into the cache
+        assertTrue(cacheShard.get(session1resp.sessionId()).isPresent());
+
+        time.sleep(500);
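+        // Session 1 is now 500 ms old, halfway to the 1000 ms eviction threshold configured above.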
+
+        // Create a second new fetch session
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> session2req = new LinkedHashMap<>();
+        session2req.put(foo0.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        session2req.put(foo1.topicPartition(), new PartitionData(fooId, 10, 0, 100, Optional.empty()));
+        FetchRequest session2request1 = createRequest(INITIAL, session2req, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext session2context = newContext(fetchManager, session2request1, topicNames);
+        assertInstanceOf(FullFetchContext.class, session2context);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> session2RespData = new LinkedHashMap<>();
+        session2RespData.put(foo0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        session2RespData.put(foo1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse session2resp = session2context.updateAndGenerateResponseData(session2RespData, List.of());
+        assertEquals(Errors.NONE, session2resp.error());
+        assertTrue(session2resp.sessionId() != INVALID_SESSION_ID);
+        assertEquals(2, session2resp.responseData(topicNames, session2request1.version()).size());
+
+        // both newly created entries are present in cache
+        assertTrue(cacheShard.get(session1resp.sessionId()).isPresent());
+        assertTrue(cacheShard.get(session2resp.sessionId()).isPresent());
+
+        time.sleep(500);
+
+        // Create an incremental fetch request for session 1
+        FetchRequest session1request2 = createRequest(new FetchMetadata(session1resp.sessionId(), 1), new LinkedHashMap<>(),
+            new ArrayList<>(), false, FETCH.latestVersion());
+        FetchContext context1v2 = newContext(fetchManager, session1request2, topicNames);
+        assertInstanceOf(IncrementalFetchContext.class, context1v2);
+
+        // total sleep time will now be large enough that fetch session 1 will be evicted if not correctly touched
+        time.sleep(501);
+
+        // create one final session to test that the least recently used entry is evicted
+        // the second session should be evicted because the first session was incrementally fetched
+        // more recently than the second session was created
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> session3req = new LinkedHashMap<>();
+        session3req.put(foo0.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        session3req.put(foo1.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        FetchRequest session3request1 = createRequest(INITIAL, session3req, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext session3context = newContext(fetchManager, session3request1, topicNames);
+        assertInstanceOf(FullFetchContext.class, session3context);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData3 = new LinkedHashMap<>();
+        respData3.put(new TopicIdPartition(fooId, new TopicPartition("foo", 0)), new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        respData3.put(new TopicIdPartition(fooId, new TopicPartition("foo", 1)), new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse session3resp = session3context.updateAndGenerateResponseData(respData3, List.of());
+        assertEquals(Errors.NONE, session3resp.error());
+        assertTrue(session3resp.sessionId() != INVALID_SESSION_ID);
+        assertEquals(2, session3resp.responseData(topicNames, session3request1.version()).size());
+
+        assertTrue(cacheShard.get(session1resp.sessionId()).isPresent());
+        assertFalse(cacheShard.get(session2resp.sessionId()).isPresent(),
+            "session 2 should have been evicted by latest session, as session 
1 was used more recently");
+        assertTrue(cacheShard.get(session3resp.sessionId()).isPresent());
+    }
+
+    @Test
+    public void testPrivilegedSessionHandling() {
+        MockTime time = new MockTime();
+        // set maximum entries to 2 to allow for eviction later
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(2, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(time, cacheShard);
+        Uuid fooId = Uuid.randomUuid();
+        Map<Uuid, String> topicNames = Map.of(fooId, "foo");
+        TopicIdPartition foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0));
+        TopicIdPartition foo1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1));
+
+        // Create a new fetch session, session 1
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> session1req = new LinkedHashMap<>();
+        session1req.put(foo0.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        session1req.put(foo1.topicPartition(), new PartitionData(fooId, 10, 0, 100, Optional.empty()));
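+        // The boolean passed to createRequest marks the request as privileged: sessions 1, 3 and 4 use true; the unprivileged session 2 uses false.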
+        FetchRequest session1request = createRequest(INITIAL, session1req, EMPTY_PART_LIST, true, FETCH.latestVersion());
+        FetchContext session1context = newContext(fetchManager, session1request, topicNames);
+        assertInstanceOf(FullFetchContext.class, session1context);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData1 = new LinkedHashMap<>();
+        respData1.put(foo0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        respData1.put(foo1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse session1resp = session1context.updateAndGenerateResponseData(respData1, List.of());
+        assertEquals(Errors.NONE, session1resp.error());
+        assertTrue(session1resp.sessionId() != INVALID_SESSION_ID);
+        assertEquals(2, session1resp.responseData(topicNames, session1request.version()).size());
+        assertEquals(1, cacheShard.size());
+
+        // move time forward to age session 1 a little compared to session 2
+        time.sleep(500);
+
+        // Create a second new fetch session, unprivileged
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> session2req = new LinkedHashMap<>();
+        session2req.put(foo0.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        session2req.put(foo1.topicPartition(), new PartitionData(fooId, 10, 0, 100, Optional.empty()));
+        FetchRequest session2request = createRequest(INITIAL, session2req, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext session2context = newContext(fetchManager, session2request, topicNames);
+        assertInstanceOf(FullFetchContext.class, session2context);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> session2RespData = new LinkedHashMap<>();
+        session2RespData.put(foo0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        session2RespData.put(foo1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse session2resp = session2context.updateAndGenerateResponseData(session2RespData, List.of());
+        assertEquals(Errors.NONE, session2resp.error());
+        assertTrue(session2resp.sessionId() != INVALID_SESSION_ID);
+        assertEquals(2, session2resp.responseData(topicNames, session2request.version()).size());
+
+        // both newly created entries are present in cache
+        assertTrue(cacheShard.get(session1resp.sessionId()).isPresent());
+        assertTrue(cacheShard.get(session2resp.sessionId()).isPresent());
+        assertEquals(2, cacheShard.size());
+
+        time.sleep(500);
+
+        // create a session to test that session 1's privilege means session 1 is retained and session 2 is evicted
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> session3req = new LinkedHashMap<>();
+        session3req.put(foo0.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        session3req.put(foo1.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        FetchRequest session3request = createRequest(INITIAL, session3req, EMPTY_PART_LIST, true, FETCH.latestVersion());
+        FetchContext session3context = newContext(fetchManager, session3request, topicNames);
+        assertInstanceOf(FullFetchContext.class, session3context);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData3 = new LinkedHashMap<>();
+        respData3.put(foo0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        respData3.put(foo1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse session3resp = session3context.updateAndGenerateResponseData(respData3, List.of());
+        assertEquals(Errors.NONE, session3resp.error());
+        assertTrue(session3resp.sessionId() != INVALID_SESSION_ID);
+        assertEquals(2, session3resp.responseData(topicNames, session3request.version()).size());
+
+        assertTrue(cacheShard.get(session1resp.sessionId()).isPresent());
+        // even though session 2 is more recent than session 1, and has not reached expiry time, it is less
+        // privileged than sessions 1 and 3, and thus session 3 should be entered and session 2 evicted.
+        assertFalse(cacheShard.get(session2resp.sessionId()).isPresent(), "session 2 should have been evicted by session 3");
+        assertTrue(cacheShard.get(session3resp.sessionId()).isPresent());
+        assertEquals(2, cacheShard.size());
+
+        time.sleep(501);
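+        // Session 1 was created 1501 ms ago (500 + 500 + 501) and never touched since, so it is past the 1000 ms eviction threshold.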
+
+        // create a final session to test whether session 1 can be evicted due to age even though it is privileged
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> session4req = new LinkedHashMap<>();
+        session4req.put(foo0.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        session4req.put(foo1.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        FetchRequest session4request = createRequest(INITIAL, session4req, EMPTY_PART_LIST, true, FETCH.latestVersion());
+        FetchContext session4context = newContext(fetchManager, session4request, topicNames);
+        assertInstanceOf(FullFetchContext.class, session4context);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData4 = new LinkedHashMap<>();
+        respData4.put(foo0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        respData4.put(foo1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse session4resp = session4context.updateAndGenerateResponseData(respData4, List.of());
+        assertEquals(Errors.NONE, session4resp.error());
+        assertTrue(session4resp.sessionId() != INVALID_SESSION_ID);
+        assertEquals(2, session4resp.responseData(topicNames, session4request.version()).size());
+
+        assertFalse(cacheShard.get(session1resp.sessionId()).isPresent(),
+            "session 1 should have been evicted by session 4 even though it is 
privileged as it has hit eviction time");
+        assertTrue(cacheShard.get(session3resp.sessionId()).isPresent());
+        assertTrue(cacheShard.get(session4resp.sessionId()).isPresent());
+        assertEquals(2, cacheShard.size());
+    }
+
+    @Test
+    public void testZeroSizeFetchSession() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+        Uuid fooId = Uuid.randomUuid();
+        Map<Uuid, String> topicNames = Map.of(fooId, "foo");
+        TopicIdPartition foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0));
+        TopicIdPartition foo1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1));
+
+        // Create a new fetch session with foo-0 and foo-1
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqData1 = new LinkedHashMap<>();
+        reqData1.put(foo0.topicPartition(), new PartitionData(fooId, 0, 0, 100, Optional.empty()));
+        reqData1.put(foo1.topicPartition(), new PartitionData(fooId, 10, 0, 100, Optional.empty()));
+        FetchRequest request1 = createRequest(INITIAL, reqData1, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context1 = newContext(fetchManager, request1, topicNames);
+        assertInstanceOf(FullFetchContext.class, context1);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData1 = new LinkedHashMap<>();
+        respData1.put(foo0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(0)
+            .setHighWatermark(100)
+            .setLastStableOffset(100)
+            .setLogStartOffset(100));
+        respData1.put(foo1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(1)
+            .setHighWatermark(10)
+            .setLastStableOffset(10)
+            .setLogStartOffset(10));
+        FetchResponse resp1 = context1.updateAndGenerateResponseData(respData1, List.of());
+        assertEquals(Errors.NONE, resp1.error());
+        assertTrue(resp1.sessionId() != INVALID_SESSION_ID);
+        assertEquals(2, resp1.responseData(topicNames, request1.version()).size());
+
+        // Create an incremental fetch request that removes foo-0 and foo-1
+        // Verify that the previous fetch session was closed.
+        FetchRequest request2 = createRequest(new FetchMetadata(resp1.sessionId(), 1), new LinkedHashMap<>(),
+            List.of(foo0, foo1), false, FETCH.latestVersion());
+        FetchContext context2 = newContext(fetchManager, request2, topicNames);
+        assertInstanceOf(SessionlessFetchContext.class, context2);
+
+        FetchResponse resp2 = context2.updateAndGenerateResponseData(new LinkedHashMap<>(), List.of());
+        assertEquals(INVALID_SESSION_ID, resp2.sessionId());
+        assertTrue(resp2.responseData(topicNames, request2.version()).isEmpty());
+        assertEquals(0, cacheShard.size());
+    }
+
+    @Test
+    public void testDivergingEpoch() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+        Map<Uuid, String> topicNames = Map.of(Uuid.randomUuid(), "foo", Uuid.randomUuid(), "bar");
+        Map<String, Uuid> topicIds = topicNames.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+        TopicIdPartition tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1));
+        TopicIdPartition tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 2));
+
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqData = new LinkedHashMap<>();
+        reqData.put(tp1.topicPartition(), new PartitionData(tp1.topicId(), 100, 0, 1000, Optional.of(5), Optional.of(4)));
+        reqData.put(tp2.topicPartition(), new PartitionData(tp2.topicId(), 100, 0, 1000, Optional.of(5), Optional.of(4)));
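+        // PartitionData arguments: topic ID, fetch offset 100, log start offset 0, max bytes 1000, current leader epoch 5, last fetched epoch 4.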
+
+        // Full fetch context returns all partitions in the response
+        FetchRequest request1 = createRequest(INITIAL, reqData, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context1 = newContext(fetchManager, request1, topicNames);
+        assertInstanceOf(FullFetchContext.class, context1);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData = new LinkedHashMap<>();
+        respData.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp1.partition())
+            .setHighWatermark(105)
+            .setLastStableOffset(105)
+            .setLogStartOffset(0));
+        FetchResponseData.EpochEndOffset divergingEpoch = new FetchResponseData.EpochEndOffset().setEpoch(3).setEndOffset(90);
+        respData.put(tp2, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp2.partition())
+            .setHighWatermark(105)
+            .setLastStableOffset(105)
+            .setLogStartOffset(0)
+            .setDivergingEpoch(divergingEpoch));
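+        // A diverging epoch signals that the fetcher's log has diverged from the leader's, so tp2 must be returned even when nothing else changed.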
+        FetchResponse resp1 = context1.updateAndGenerateResponseData(respData, List.of());
+        assertEquals(Errors.NONE, resp1.error());
+        assertNotEquals(INVALID_SESSION_ID, resp1.sessionId());
+        assertEquals(Set.of(tp1.topicPartition(), tp2.topicPartition()), resp1.responseData(topicNames, request1.version()).keySet());
+
+        // Incremental fetch context returns partitions with divergent epoch even if none
+        // of the other conditions for return are met.
+        FetchRequest request2 = createRequest(new FetchMetadata(resp1.sessionId(), 1), reqData, EMPTY_PART_LIST, false, FETCH.latestVersion());
+        FetchContext context2 = newContext(fetchManager, request2, topicNames);
+        assertInstanceOf(IncrementalFetchContext.class, context2);
+
+        FetchResponse resp2 = context2.updateAndGenerateResponseData(respData, List.of());
+        assertEquals(Errors.NONE, resp2.error());
+        assertEquals(resp1.sessionId(), resp2.sessionId());
+        assertEquals(Set.of(tp2.topicPartition()), resp2.responseData(topicNames, request2.version()).keySet());
+
+        // All partitions with divergent epoch should be returned.
+        respData.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp1.partition())
+            .setHighWatermark(105)
+            .setLastStableOffset(105)
+            .setLogStartOffset(0)
+            .setDivergingEpoch(divergingEpoch));
+        FetchResponse resp3 = context2.updateAndGenerateResponseData(respData, List.of());
+        assertEquals(Errors.NONE, resp3.error());
+        assertEquals(resp1.sessionId(), resp3.sessionId());
+        assertEquals(Set.of(tp1.topicPartition(), tp2.topicPartition()), resp3.responseData(topicNames, request2.version()).keySet());
+
+        // Partitions that meet other conditions should be returned regardless of whether
+        // divergingEpoch is set or not.
+        respData.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp1.partition())
+            .setHighWatermark(110)
+            .setLastStableOffset(110)
+            .setLogStartOffset(0));
+        FetchResponse resp4 = context2.updateAndGenerateResponseData(respData, List.of());
+        assertEquals(Errors.NONE, resp4.error());
+        assertEquals(resp1.sessionId(), resp4.sessionId());
+        assertEquals(Set.of(tp1.topicPartition(), tp2.topicPartition()), resp4.responseData(topicNames, request2.version()).keySet());
+    }
+
+    @Test
+    public void testDeprioritizesPartitionsWithRecordsOnly() {
+        FetchSessionCacheShard cacheShard = new FetchSessionCacheShard(10, 1000, Integer.MAX_VALUE, 0);
+        FetchManager fetchManager = new FetchManager(new MockTime(), cacheShard);
+        Map<String, Uuid> topicIds = Map.of("foo", Uuid.randomUuid(), "bar", Uuid.randomUuid(), "zar", Uuid.randomUuid());
+        Map<Uuid, String> topicNames = topicIds.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+        TopicIdPartition tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1));
+        TopicIdPartition tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 2));
+        TopicIdPartition tp3 = new TopicIdPartition(topicIds.get("zar"), new TopicPartition("zar", 3));
+
+        LinkedHashMap<TopicIdPartition, FetchRequest.PartitionData> reqData = new LinkedHashMap<>();
+        reqData.put(tp1, new PartitionData(tp1.topicId(), 100, 0, 1000, Optional.of(5), Optional.of(4)));
+        reqData.put(tp2, new PartitionData(tp2.topicId(), 100, 0, 1000, Optional.of(5), Optional.of(4)));
+        reqData.put(tp3, new PartitionData(tp3.topicId(), 100, 0, 1000, Optional.of(5), Optional.of(4)));
+
+        // Full fetch context returns all partitions in the response
+        FetchContext context1 = fetchManager.newContext(FETCH.latestVersion(), INITIAL, false, reqData, List.of(), topicNames);
+        assertInstanceOf(FullFetchContext.class, context1);
+
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData1 = new LinkedHashMap<>();
+        respData1.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp1.topicPartition().partition())
+            .setHighWatermark(50)
+            .setLastStableOffset(50)
+            .setLogStartOffset(0));
+        respData1.put(tp2, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp2.topicPartition().partition())
+            .setHighWatermark(50)
+            .setLastStableOffset(50)
+            .setLogStartOffset(0));
+        respData1.put(tp3, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp3.topicPartition().partition())
+            .setHighWatermark(50)
+            .setLastStableOffset(50)
+            .setLogStartOffset(0));
+
+        FetchResponse resp1 = context1.updateAndGenerateResponseData(respData1, List.of());
+        assertEquals(Errors.NONE, resp1.error());
+        assertNotEquals(INVALID_SESSION_ID, resp1.sessionId());
+        assertEquals(
+            Set.of(tp1.topicPartition(), tp2.topicPartition(), tp3.topicPartition()),
+            resp1.responseData(topicNames, FETCH.latestVersion()).keySet()
+        );
+
+        // Incremental fetch context returns partitions with changes but only deprioritizes
+        // the partitions with records
+        FetchContext context2 = fetchManager.newContext(FETCH.latestVersion(), new FetchMetadata(resp1.sessionId(), 1),
+            false, reqData, List.of(), topicNames);
+        assertInstanceOf(IncrementalFetchContext.class, context2);
+
+        // Partitions are ordered in the session as per last response
+        assertPartitionsOrder(context2, List.of(tp1, tp2, tp3));
+
+        // Response is empty
+        FetchResponse resp2 = context2.updateAndGenerateResponseData(new LinkedHashMap<>(), List.of());
+        assertEquals(Errors.NONE, resp2.error());
+        assertEquals(resp1.sessionId(), resp2.sessionId());
+        assertEquals(Set.of(), resp2.responseData(topicNames, FETCH.latestVersion()).keySet());
+
+        // All partitions with changes should be returned.
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData3 = new LinkedHashMap<>();
+        respData3.put(tp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp1.topicPartition().partition())
+            .setHighWatermark(60)
+            .setLastStableOffset(50)
+            .setLogStartOffset(0));
+        respData3.put(tp2, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp2.topicPartition().partition())
+            .setHighWatermark(60)
+            .setLastStableOffset(50)
+            .setLogStartOffset(0)
+            .setRecords(MemoryRecords.withRecords(Compression.NONE, new SimpleRecord(100, null))));
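+        // tp2 is the only partition returning records in this response, so it alone should be deprioritized.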
+        respData3.put(tp3, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp3.topicPartition().partition())
+            .setHighWatermark(50)
+            .setLastStableOffset(50)
+            .setLogStartOffset(0));
+        FetchResponse resp3 = context2.updateAndGenerateResponseData(respData3, List.of());
+        assertEquals(Errors.NONE, resp3.error());
+        assertEquals(resp1.sessionId(), resp3.sessionId());
+        assertEquals(Set.of(tp1.topicPartition(), tp2.topicPartition()), resp3.responseData(topicNames, FETCH.latestVersion()).keySet());
+
+        // Only the partitions that returned records in the last response
+        // are deprioritized
+        assertPartitionsOrder(context2, List.of(tp1, tp3, tp2));
+    }
+
+    @Test
+    public void testCachedPartitionEqualsAndHashCode() {
+        Uuid topicId = Uuid.randomUuid();
+        String topicName = "topic";
+        int partition = 0;
+
+        CachedPartition cachedPartitionWithIdAndName = new CachedPartition(topicName, topicId, partition);
+        CachedPartition cachedPartitionWithIdAndNoName = new CachedPartition(null, topicId, partition);
+        CachedPartition cachedPartitionWithDifferentIdAndName = new CachedPartition(topicName, Uuid.randomUuid(), partition);
+        CachedPartition cachedPartitionWithZeroIdAndName = new CachedPartition(topicName, Uuid.ZERO_UUID, partition);
+        CachedPartition cachedPartitionWithZeroIdAndOtherName = new CachedPartition("otherTopic", Uuid.ZERO_UUID, partition);
+
+        // CachedPartitions with valid topic IDs will compare topic ID and partition but not topic name.
+        assertEquals(cachedPartitionWithIdAndName, cachedPartitionWithIdAndNoName);
+        assertEquals(cachedPartitionWithIdAndName.hashCode(), cachedPartitionWithIdAndNoName.hashCode());
+
+        assertNotEquals(cachedPartitionWithIdAndName, cachedPartitionWithDifferentIdAndName);
+        assertNotEquals(cachedPartitionWithIdAndName.hashCode(), cachedPartitionWithDifferentIdAndName.hashCode());
+
+        assertNotEquals(cachedPartitionWithIdAndName, cachedPartitionWithZeroIdAndName);
+        assertNotEquals(cachedPartitionWithIdAndName.hashCode(), cachedPartitionWithZeroIdAndName.hashCode());
+
+        // CachedPartitions with null name and valid IDs will act just like ones with valid names
+        assertEquals(cachedPartitionWithIdAndNoName, cachedPartitionWithIdAndName);
+        assertEquals(cachedPartitionWithIdAndNoName.hashCode(), cachedPartitionWithIdAndName.hashCode());
+
+        assertNotEquals(cachedPartitionWithIdAndNoName, cachedPartitionWithDifferentIdAndName);
+        assertNotEquals(cachedPartitionWithIdAndNoName.hashCode(), cachedPartitionWithDifferentIdAndName.hashCode());
+
+        assertNotEquals(cachedPartitionWithIdAndNoName, cachedPartitionWithZeroIdAndName);
+        assertNotEquals(cachedPartitionWithIdAndNoName.hashCode(), cachedPartitionWithZeroIdAndName.hashCode());
+
+        // CachedPartitions with zero UUIDs will compare topic name and partition.
+        assertNotEquals(cachedPartitionWithZeroIdAndName, cachedPartitionWithZeroIdAndOtherName);
+        assertNotEquals(cachedPartitionWithZeroIdAndName.hashCode(), cachedPartitionWithZeroIdAndOtherName.hashCode());
+
+        assertEquals(cachedPartitionWithZeroIdAndName, cachedPartitionWithZeroIdAndName);
+        assertEquals(cachedPartitionWithZeroIdAndName.hashCode(), cachedPartitionWithZeroIdAndName.hashCode());
+    }
+
+    @Test
+    public void testMaybeResolveUnknownName() {
+        CachedPartition namedPartition = new CachedPartition("topic", Uuid.randomUuid(), 0);
+        CachedPartition nullNamePartition1 = new CachedPartition(null, Uuid.randomUuid(), 0);
+        CachedPartition nullNamePartition2 = new CachedPartition(null, Uuid.randomUuid(), 0);
+        Map<Uuid, String> topicNames = Map.of(namedPartition.topicId(), "foo", nullNamePartition1.topicId(), "bar");
+
+        // Since the name is not null, we should not change the topic name.
+        // We should never have a scenario where the same ID is used by two topic names, but this tests that we respect the null check.
+        namedPartition.maybeResolveUnknownName(topicNames);
+        assertEquals("topic", namedPartition.topic());
+
+        // We will resolve this name as it is in the map and the current name is null.
+        nullNamePartition1.maybeResolveUnknownName(topicNames);
+        assertEquals("bar", nullNamePartition1.topic());
+
+        // If the ID is not in the map, then we don't resolve the name.
+        nullNamePartition2.maybeResolveUnknownName(topicNames);
+        assertNull(nullNamePartition2.topic());
+    }
+
+    @Test
+    public void testFetchSessionCache_getShardedCache_retrievesCacheFromCorrectSegment() {
+        // Given
+        int numShards = 8;
+        int sessionIdRange = Integer.MAX_VALUE / numShards;
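+        // Each shard owns a contiguous session ID range: IDs in [n * sessionIdRange, (n + 1) * sessionIdRange) map to shard n, as the assertions below verify.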
+        List<FetchSessionCacheShard> cacheShards = IntStream.range(0, numShards)
+            .mapToObj(shardNum -> new FetchSessionCacheShard(10, 1000, sessionIdRange, shardNum))
+            .toList();
+        FetchSessionCache cache = new FetchSessionCache(cacheShards);
+
+        // When
+        FetchSessionCacheShard cache0 = cache.getCacheShard(sessionIdRange - 1);
+        FetchSessionCacheShard cache1 = cache.getCacheShard(sessionIdRange);
+        FetchSessionCacheShard cache2 = cache.getCacheShard(sessionIdRange * 2);
+
+        // Then
+        assertEquals(cache0, cacheShards.get(0));
+        assertEquals(cache1, cacheShards.get(1));
+        assertEquals(cache2, cacheShards.get(2));
+        assertThrows(IndexOutOfBoundsException.class, () -> cache.getCacheShard(sessionIdRange * numShards));
+    }
+
+    @Test
+    public void testFetchSessionCache_RoundRobinsIntoShards() {
+        // Given
+        int numShards = 8;
+        int sessionIdRange = Integer.MAX_VALUE / numShards;
+        List<FetchSessionCacheShard> cacheShards = IntStream.range(0, numShards)
+            .mapToObj(shardNum -> new FetchSessionCacheShard(10, 1000, sessionIdRange, shardNum))
+            .toList();
+        FetchSessionCache cache = new FetchSessionCache(cacheShards);
+
+        // When / Then
+        for (int shardNum = 0; shardNum < numShards * 2; shardNum++)
+            assertEquals(cacheShards.get(shardNum % numShards), cache.getNextCacheShard());
+    }
+
+    @Test
+    public void testFetchSessionCache_RoundRobinsIntoShards_WhenIntegerOverflows() {
+        // Given
+        int maxInteger = Integer.MAX_VALUE;
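+        // Integer.MAX_VALUE + 1 overflows to Integer.MIN_VALUE, simulating a round-robin counter that has wrapped around.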
+        FetchSessionCache.COUNTER.set(maxInteger + 1);
+        int numShards = 8;
+        int sessionIdRange = Integer.MAX_VALUE / numShards;
+        List<FetchSessionCacheShard> cacheShards = IntStream.range(0, numShards)
+            .mapToObj(shardNum -> new FetchSessionCacheShard(10, 1000, sessionIdRange, shardNum))
+            .toList();
+        FetchSessionCache cache = new FetchSessionCache(cacheShards);
+
+        // When / Then
+        for (int shardNum = 0; shardNum < numShards * 2; shardNum++)
+            assertEquals(cacheShards.get(shardNum % numShards), cache.getNextCacheShard());
+    }
+
+    private void assertCacheContains(FetchSessionCacheShard cacheShard, int... sessionIds) {
+        int i = 0;
+        for (int sessionId : sessionIds) {
+            i++;
+            assertTrue(cacheShard.get(sessionId).isPresent(),
+                "Missing session " + i + " out of " + sessionIds.length + ": " + sessionId);
+        }
+        }
+        assertEquals(sessionIds.length, cacheShard.size());
+    }
+
+    private ImplicitLinkedHashCollection<CachedPartition> dummyCreate(int size) {

Review Comment:
   Renamed it



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

