This is an automated email from the ASF dual-hosted git repository.

mittal pushed a commit to branch 4.2
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/4.2 by this push:
     new 0e4f27dd843 KAFKA-19950: Mark session stale when removed prior 
disconnection event (#21062) (#21071)
0e4f27dd843 is described below

commit 0e4f27dd843917de6b89fad6551d9ba18f33e76a
Author: Apoorv Mittal <[email protected]>
AuthorDate: Wed Dec 3 19:27:49 2025 +0000

    KAFKA-19950: Mark session stale when removed prior disconnection event 
(#21062) (#21071)
    
    The PR marks the old session as stale in the connection-id-to-session map
    to avoid triggering the member leave event when the old connection's
    disconnect event is processed. This is most often seen when the client
    re-sends the share fetch request with the `initial epoch` and the broker
    then refreshes the connection.
    
    Reviewers: Andrew Schofield <[email protected]>, Abhinav Dixit
     <[email protected]>
---
 core/src/main/scala/kafka/server/KafkaApis.scala   |   4 +-
 .../scala/unit/kafka/server/KafkaApisTest.scala    | 133 +++++++++++----------
 .../kafka/server/share/session/ShareSession.java   |  17 ++-
 .../server/share/session/ShareSessionCache.java    | 133 ++++++++++++++++-----
 .../share/session/ShareSessionCacheTest.java       |  40 +++++++
 5 files changed, 233 insertions(+), 94 deletions(-)

diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala 
b/core/src/main/scala/kafka/server/KafkaApis.scala
index affdced8cf4..c3667d00154 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -3370,7 +3370,7 @@ class KafkaApis(val requestChannel: RequestChannel,
                   error(s"Releasing share session close with correlation from 
client ${request.header.clientId}  " +
                     s"failed with error ${throwable.getMessage}")
                 } else {
-                  info(s"Releasing share session close 
$releaseAcquiredRecordsData succeeded")
+                  info(s"Releasing share session for client id 
${request.header.clientId} succeeded, response: $releaseAcquiredRecordsData")
                 }
               )
           }
@@ -3594,7 +3594,7 @@ class KafkaApis(val requestChannel: RequestChannel,
                   debug(s"Releasing share session close with correlation from 
client ${request.header.clientId}  " +
                     s"failed with error ${throwable.getMessage}")
                 } else {
-                  info(s"Releasing share session close 
$releaseAcquiredRecordsData succeeded")
+                  info(s"Releasing share session for client id 
${request.header.clientId} succeeded, response: $releaseAcquiredRecordsData")
                 }
               }
           }
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala 
b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index dc0dbe11bdd..48e35ddf6bc 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -4777,12 +4777,6 @@ class KafkaApisTest extends Logging {
     cachedSharePartitions.mustAdd(new CachedSharePartition(
       new TopicIdPartition(topicId, partitionIndex, topicName), false))
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any())).thenThrow(
-      Errors.INVALID_REQUEST.exception()
-    ).thenReturn(new ShareSessionContext(1, new ShareSession(
-      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2
-    )))
-
     when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
       any[Session](), anyString, anyDouble, anyLong)).thenReturn(0)
 
@@ -4806,6 +4800,13 @@ class KafkaApisTest extends Logging {
 
     var shareFetchRequest = new 
ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
     var request = buildRequest(shareFetchRequest)
+
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any())).thenThrow(
+      Errors.INVALID_REQUEST.exception()
+    ).thenReturn(new ShareSessionContext(1, new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2, 
request.context.connectionId)
+    ))
+
     kafkaApis = createKafkaApis()
     kafkaApis.handleShareFetchRequest(request)
     var response = verifyNoThrottling[ShareFetchResponse](request)
@@ -5024,11 +5025,6 @@ class KafkaApisTest extends Logging {
     cachedSharePartitions.mustAdd(new CachedSharePartition(
       new TopicIdPartition(topicId, partitionIndex, topicName), false))
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any()))
-      .thenReturn(new ShareSessionContext(1, new ShareSession(
-        new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2))
-      )
-
     when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
       any[Session](), anyString, anyDouble, anyLong)).thenReturn(0)
 
@@ -5052,6 +5048,12 @@ class KafkaApisTest extends Logging {
 
     val shareFetchRequest = new 
ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
     val request = buildRequest(shareFetchRequest)
+
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any()))
+      .thenReturn(new ShareSessionContext(1, new ShareSession(
+        new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2, 
request.context.connectionId))
+      )
+
     kafkaApis = createKafkaApis()
     kafkaApis.handleShareFetchRequest(request)
     val response = verifyNoThrottling[ShareFetchResponse](request)
@@ -5083,11 +5085,6 @@ class KafkaApisTest extends Logging {
     cachedSharePartitions.mustAdd(new CachedSharePartition(
       new TopicIdPartition(topicId, partitionIndex, topicName), false))
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any()))
-      .thenReturn(new ShareSessionContext(1, new ShareSession(
-        new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2))
-      )
-
     when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
       any[Session](), anyString, anyDouble, anyLong)).thenReturn(0)
 
@@ -5111,6 +5108,12 @@ class KafkaApisTest extends Logging {
 
     val shareFetchRequest = new 
ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
     val request = buildRequest(shareFetchRequest)
+
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any()))
+      .thenReturn(new ShareSessionContext(1, new ShareSession(
+        new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2, 
request.context.connectionId))
+      )
+
     kafkaApis = createKafkaApis()
     kafkaApis.handleShareFetchRequest(request)
     val response = verifyNoThrottling[ShareFetchResponse](request)
@@ -5460,16 +5463,6 @@ class KafkaApisTest extends Logging {
       new TopicIdPartition(topicId, partitionIndex, topicName), false)
     )
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any())).thenReturn(
-      new ShareSessionContext(0, util.List.of(
-        new TopicIdPartition(topicId, partitionIndex, topicName)
-      ))
-    ).thenReturn(new ShareSessionContext(1, new ShareSession(
-      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2))
-    ).thenReturn(new ShareSessionContext(2, new ShareSession(
-      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 3))
-    )
-
     when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
       any[Session](), anyString, anyDouble, anyLong)).thenReturn(0)
 
@@ -5486,6 +5479,16 @@ class KafkaApisTest extends Logging {
     var shareFetchRequest = new 
ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
     var request = buildRequest(shareFetchRequest)
 
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any())).thenReturn(
+      new ShareSessionContext(0, util.List.of(
+        new TopicIdPartition(topicId, partitionIndex, topicName)
+      ))
+    ).thenReturn(new ShareSessionContext(1, new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2, 
request.context.connectionId))
+    ).thenReturn(new ShareSessionContext(2, new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 3, 
request.context.connectionId))
+    )
+
     // First share fetch request is to establish the share session with the 
broker.
     kafkaApis = createKafkaApis()
     kafkaApis.handleShareFetchRequest(request)
@@ -5725,19 +5728,6 @@ class KafkaApisTest extends Logging {
       new TopicIdPartition(topicId4, 0, topicName4), false
     ))
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any())).thenReturn(
-      new ShareSessionContext(0, util.List.of(
-        new TopicIdPartition(topicId1, new TopicPartition(topicName1, 0)),
-        new TopicIdPartition(topicId1, new TopicPartition(topicName1, 1)),
-        new TopicIdPartition(topicId2, new TopicPartition(topicName2, 0)),
-        new TopicIdPartition(topicId2, new TopicPartition(topicName2, 1))
-      ))
-    ).thenReturn(new ShareSessionContext(1, new ShareSession(
-      new ShareSessionKey(groupId, memberId), cachedSharePartitions1, 2))
-    ).thenReturn(new ShareSessionContext(2, new ShareSession(
-      new ShareSessionKey(groupId, memberId), cachedSharePartitions2, 3))
-    ).thenReturn(new FinalContext())
-
     when(sharePartitionManager.releaseSession(any(), any())).thenReturn(
       CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, 
ShareAcknowledgeResponseData.PartitionData](
         new TopicIdPartition(topicId3, new TopicPartition(topicName3, 0)),
@@ -5808,6 +5798,20 @@ class KafkaApisTest extends Logging {
 
     var shareFetchRequest = new 
ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
     var request = buildRequest(shareFetchRequest)
+
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any())).thenReturn(
+      new ShareSessionContext(0, util.List.of(
+        new TopicIdPartition(topicId1, new TopicPartition(topicName1, 0)),
+        new TopicIdPartition(topicId1, new TopicPartition(topicName1, 1)),
+        new TopicIdPartition(topicId2, new TopicPartition(topicName2, 0)),
+        new TopicIdPartition(topicId2, new TopicPartition(topicName2, 1))
+      ))
+    ).thenReturn(new ShareSessionContext(1, new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions1, 2, 
request.context.connectionId))
+    ).thenReturn(new ShareSessionContext(2, new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions2, 3, 
request.context.connectionId))
+    ).thenReturn(new FinalContext())
+
     // First share fetch request is to establish the share session with the 
broker.
     kafkaApis = createKafkaApis()
     kafkaApis.handleShareFetchRequest(request)
@@ -6688,14 +6692,6 @@ class KafkaApisTest extends Logging {
       new TopicIdPartition(topicId, 0, topicName), false
     ))
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any())).thenReturn(
-      new ShareSessionContext(0, util.List.of(
-        new TopicIdPartition(topicId, partitionIndex, topicName)
-      ))
-    ).thenReturn(new ShareSessionContext(1, new ShareSession(
-      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2))
-    )
-
     when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
       any[Session](), anyString, anyDouble, anyLong)).thenReturn(0)
 
@@ -6720,6 +6716,15 @@ class KafkaApisTest extends Logging {
 
     var shareFetchRequest = new 
ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
     var request = buildRequest(shareFetchRequest)
+
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any())).thenReturn(
+      new ShareSessionContext(0, util.List.of(
+        new TopicIdPartition(topicId, partitionIndex, topicName)
+      ))
+    ).thenReturn(new ShareSessionContext(1, new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2, 
request.context.connectionId))
+    )
+
     kafkaApis = createKafkaApis()
     kafkaApis.handleShareFetchRequest(request)
     var response = verifyNoThrottling[ShareFetchResponse](request)
@@ -6807,14 +6812,6 @@ class KafkaApisTest extends Logging {
       new TopicIdPartition(topicId, 0, topicName), false
     ))
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any())).thenReturn(
-      new ShareSessionContext(0, util.List.of(
-        new TopicIdPartition(topicId, partitionIndex, topicName)
-      ))
-    ).thenReturn(new ShareSessionContext(1, new ShareSession(
-      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2))
-    )
-
     when(sharePartitionManager.acknowledge(any(), any(), any())).thenReturn(
       CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, 
ShareAcknowledgeResponseData.PartitionData](
         new TopicIdPartition(topicId, new TopicPartition(topicName, 0)),
@@ -6837,6 +6834,15 @@ class KafkaApisTest extends Logging {
 
     var shareFetchRequest = new 
ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
     var request = buildRequest(shareFetchRequest)
+
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any())).thenReturn(
+      new ShareSessionContext(0, util.List.of(
+        new TopicIdPartition(topicId, partitionIndex, topicName)
+      ))
+    ).thenReturn(new ShareSessionContext(1, new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2, 
request.context.connectionId))
+    )
+
     kafkaApis = createKafkaApis()
     kafkaApis.handleShareFetchRequest(request)
     var response = verifyNoThrottling[ShareFetchResponse](request)
@@ -14299,14 +14305,6 @@ class KafkaApisTest extends Logging {
       new TopicIdPartition(topicId, 0, topicName), false
     ))
 
-    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any())).thenReturn(
-      new ShareSessionContext(0, util.List.of(
-        new TopicIdPartition(topicId, partitionIndex, topicName)
-      ))
-    ).thenReturn(new ShareSessionContext(1, new ShareSession(
-      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2))
-    )
-
     when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
       any[Session](), anyString, anyDouble, anyLong)).thenReturn(0)
 
@@ -14331,6 +14329,15 @@ class KafkaApisTest extends Logging {
 
     var shareFetchRequest = new 
ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
     var request = buildRequest(shareFetchRequest)
+
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any(), any())).thenReturn(
+      new ShareSessionContext(0, util.List.of(
+        new TopicIdPartition(topicId, partitionIndex, topicName)
+      ))
+    ).thenReturn(new ShareSessionContext(1, new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2, 
request.context.connectionId))
+    )
+
     kafkaApis = createKafkaApis()
     kafkaApis.handleShareFetchRequest(request)
     var response = verifyNoThrottling[ShareFetchResponse](request)
diff --git 
a/server/src/main/java/org/apache/kafka/server/share/session/ShareSession.java 
b/server/src/main/java/org/apache/kafka/server/share/session/ShareSession.java
index 5cb800c5524..7f9a47e29df 100644
--- 
a/server/src/main/java/org/apache/kafka/server/share/session/ShareSession.java
+++ 
b/server/src/main/java/org/apache/kafka/server/share/session/ShareSession.java
@@ -38,6 +38,7 @@ public class ShareSession {
 
     private final ShareSessionKey key;
     private final ImplicitLinkedHashCollection<CachedSharePartition> 
partitionMap;
+    private final String connectionId;
 
     // visible for testing
     public int epoch;
@@ -48,16 +49,23 @@ public class ShareSession {
     /**
      * The share session.
      * Each share session is protected by its own lock, which must be taken 
before mutable
-     * fields are read or modified.  This includes modification of the share 
session partition map.
+     * fields are read or modified. This includes modification of the share 
session partition map.
      *
      * @param key                The share session key to identify the share 
session uniquely.
      * @param partitionMap       The CachedPartitionMap.
      * @param epoch              The share session sequence number.
+     * @param connectionId       The connection id associated with this share 
session.
      */
-    public ShareSession(ShareSessionKey key, 
ImplicitLinkedHashCollection<CachedSharePartition> partitionMap, int epoch) {
+    public ShareSession(
+        ShareSessionKey key,
+        ImplicitLinkedHashCollection<CachedSharePartition> partitionMap,
+        int epoch,
+        String connectionId
+    ) {
         this.key = key;
         this.partitionMap = partitionMap;
         this.epoch = epoch;
+        this.connectionId = connectionId;
     }
 
     public ShareSessionKey key() {
@@ -85,6 +93,10 @@ public class ShareSession {
         return partitionMap.isEmpty();
     }
 
+    public String connectionId() {
+        return connectionId;
+    }
+
     // Update the cached partition data based on the request.
     public synchronized Map<ModifiedTopicIdPartitionType, 
List<TopicIdPartition>> update(
         List<TopicIdPartition> shareFetchData,
@@ -138,6 +150,7 @@ public class ShareSession {
                 ", partitionMap=" + partitionMap +
                 ", epoch=" + epoch +
                 ", cachedSize=" + cachedSize +
+                ", connectionId=" + connectionId +
                 ")";
     }
 }
diff --git 
a/server/src/main/java/org/apache/kafka/server/share/session/ShareSessionCache.java
 
b/server/src/main/java/org/apache/kafka/server/share/session/ShareSessionCache.java
index 7a82098712c..01051698039 100644
--- 
a/server/src/main/java/org/apache/kafka/server/share/session/ShareSessionCache.java
+++ 
b/server/src/main/java/org/apache/kafka/server/share/session/ShareSessionCache.java
@@ -67,7 +67,7 @@ public class ShareSessionCache {
      * The map to store the client connection id to session key. This is used 
to remove the session
      * from the cache when the respective client disconnects.
      */
-    private final Map<String, ShareSessionKey> connectionIdToSessionMap;
+    private final Map<String, SessionKeyAndState> connectionIdToSessionMap;
     /**
      * The listener for share group events. This is used to notify the 
listener when the group members
      * change.
@@ -113,6 +113,7 @@ public class ShareSessionCache {
         sessions.clear();
         numMembersPerGroup.clear();
         numPartitions = 0;
+        // Avoid cleaning up connectionIdToSessionMap as that map is cleaned 
when the client disconnects.
     }
 
     public synchronized long totalPartitions() {
@@ -127,14 +128,23 @@ public class ShareSessionCache {
     }
 
     /**
-     * Maybe remove the session and notify listeners. This is called when the 
connection is disconnected
-     * for the client. The session may have already been removed by the client 
as part of final epoch,
-     * hence check if the session is still present in the cache.
+     * Maybe remove the session and notify member leave listener. This is 
called when the connection
+     * is disconnected for the client. The session may have already been 
removed by the client as part
+     * of final epoch, hence check if the session is still present in the 
cache.
      *
      * @param key The share session key.
      */
-    public synchronized void maybeRemoveAndNotifyListeners(ShareSessionKey 
key) {
-        ShareSession session = get(key);
+    private void maybeRemoveAndNotifyListenersOnMemberLeave(ShareSessionKey 
key) {
+        ShareSession session;
+        synchronized (this) {
+            session = get(key);
+            if (session != null) {
+                // As session is not null hence it's removed as part of 
connection disconnect. Hence,
+                // update the evictions metric.
+                evictionsMeter.mark();
+            }
+        }
+
         if (session != null) {
             // Notify the share group listener that member has left the group. 
Notify listener prior
             // removing the session from the cache to ensure that the listener 
has access to the session
@@ -142,23 +152,10 @@ public class ShareSessionCache {
             if (shareGroupListener != null) {
                 shareGroupListener.onMemberLeave(key.groupId(), 
key.memberId());
             }
-            // As session is not null hence it's removed as part of connection 
disconnect. Hence,
-            // update the evictions metric.
-            evictionsMeter.mark();
             // Try removing session if not already removed. The listener might 
have removed the session
             // already.
             remove(session);
         }
-        // Notify the share group listener if the group is empty. This should 
be checked regardless
-        // session is evicted by connection disconnect or client's final epoch.
-        int numMembers = numMembersPerGroup.getOrDefault(key.groupId(), 0);
-        if (numMembers == 0) {
-            // Remove the group from the map as it is empty.
-            numMembersPerGroup.remove(key.groupId());
-            if (shareGroupListener != null) {
-                shareGroupListener.onGroupEmpty(key.groupId());
-            }
-        }
     }
 
     /**
@@ -172,6 +169,17 @@ public class ShareSessionCache {
         if (removeResult != null) {
             numPartitions = numPartitions - session.cachedSize();
             numMembersPerGroup.compute(session.key().groupId(), (k, v) -> v != 
null ? v - 1 : 0);
+            // Mark the session as stale in the connectionIdToSessionMap to 
avoid removing
+            // the active session for the client. When client re-sends the 
initial epoch where the
+            // broker removes the prior session and establishes new session, 
then sometimes the connection
+            // id is changed. This leads to removal of the new session from 
the cache, when old connection
+            // disconnect event is processed. Marking the old connection as 
stale avoids this issue.
+            // If the connection id remains same for both old and new session, 
then the subsequent call
+            // for session creation will overwrite the prior stale mapping in 
the connectionIdToSessionMap.
+            SessionKeyAndState sessionKeyAndState = 
connectionIdToSessionMap.get(session.connectionId());
+            if (sessionKeyAndState != null) {
+                sessionKeyAndState.markStale();
+            }
         }
         return removeResult;
     }
@@ -201,11 +209,11 @@ public class ShareSessionCache {
     ) {
         if (sessions.size() < maxEntries) {
             ShareSession session = new ShareSession(new 
ShareSessionKey(groupId, memberId), partitionMap,
-                
ShareRequestMetadata.nextEpoch(ShareRequestMetadata.INITIAL_EPOCH));
+                
ShareRequestMetadata.nextEpoch(ShareRequestMetadata.INITIAL_EPOCH), 
clientConnectionId);
             sessions.put(session.key(), session);
             updateNumPartitions(session);
             numMembersPerGroup.compute(session.key().groupId(), (k, v) -> v != 
null ? v + 1 : 1);
-            connectionIdToSessionMap.put(clientConnectionId, session.key());
+            connectionIdToSessionMap.put(clientConnectionId, new 
SessionKeyAndState(session.key()));
             return session.key();
         }
         return null;
@@ -219,6 +227,39 @@ public class ShareSessionCache {
         this.shareGroupListener = shareGroupListener;
     }
 
+    /**
+     * Remove the connection id to session mapping when the connection is 
closed.
+     *
+     * @param connectionId The client connection id.
+     * @return The session key, or null if no such mapping was found.
+     */
+    private synchronized SessionKeyAndState 
maybeRemoveConnectionFromSession(String connectionId) {
+        return connectionIdToSessionMap.remove(connectionId);
+    }
+
+    /**
+     * Check if the share group is empty and notify the share group listener.
+     *
+     * @param groupId The share group id.
+     */
+    private void checkAndNotifyListenersOnGroupEmpty(String groupId) {
+        boolean notify = false;
+        synchronized (this) {
+            int numMembers = numMembersPerGroup.getOrDefault(groupId, 0);
+            if (numMembers == 0) {
+                // Remove the group from the map as it is empty.
+                numMembersPerGroup.remove(groupId);
+                if (shareGroupListener != null) {
+                    notify = true;
+                }
+            }
+        }
+        // Notify outside the synchronized block to avoid potential deadlocks.
+        if (notify) {
+            shareGroupListener.onGroupEmpty(groupId);
+        }
+    }
+
     // Visible for testing.
     Meter evictionsMeter() {
         return evictionsMeter;
@@ -229,18 +270,56 @@ public class ShareSessionCache {
         return numMembersPerGroup.get(groupId);
     }
 
+    // Visible for testing.
+    synchronized SessionKeyAndState connectionSessionKeyAndState(String 
connectionId) {
+        return connectionIdToSessionMap.get(connectionId);
+    }
+
     private final class ClientConnectionDisconnectListener implements 
ConnectionDisconnectListener {
 
         // When the client disconnects, the corresponding session should be 
removed from the cache.
         @Override
         public void onDisconnect(String connectionId) {
-            ShareSessionKey shareSessionKey = 
connectionIdToSessionMap.remove(connectionId);
-            if (shareSessionKey != null) {
-                // Try removing session and notify listeners. The session 
might already be removed
-                // as part of final epoch from client, so we need to check if 
the session is still
-                // present in the cache.
-                maybeRemoveAndNotifyListeners(shareSessionKey);
+            SessionKeyAndState sessionKeyAndState = 
maybeRemoveConnectionFromSession(connectionId);
+            if (sessionKeyAndState != null) {
+                // If the session is not stale, try removing the session and 
notify listeners.
+                if (!sessionKeyAndState.stale()) {
+                    // Try removing session and notify listeners. The session 
might already be removed
+                    // as part of final epoch from client, so we need to check 
if the session is still
+                    // present in the cache.
+                    
maybeRemoveAndNotifyListenersOnMemberLeave(sessionKeyAndState.shareSessionKey());
+                }
+                // Notify the share group listener if the group is empty. This 
should be checked regardless
+                // session is evicted by connection disconnect or client's 
final epoch.
+                
checkAndNotifyListenersOnGroupEmpty(sessionKeyAndState.shareSessionKey().groupId());
             }
         }
     }
+
+    /**
+     * The class records the session key and tracks if the session is stale. 
The session is marked stale
+     * when the session is removed from the cache prior to the client 
disconnect event.
+     */
+    // Visible for testing.
+    static class SessionKeyAndState {
+        private final ShareSessionKey shareSessionKey;
+        private boolean stale;
+
+        SessionKeyAndState(ShareSessionKey shareSessionKey) {
+            this.shareSessionKey = shareSessionKey;
+            this.stale = false;
+        }
+
+        ShareSessionKey shareSessionKey() {
+            return shareSessionKey;
+        }
+
+        boolean stale() {
+            return stale;
+        }
+
+        void markStale() {
+            this.stale = true;
+        }
+    }
 }
diff --git 
a/server/src/test/java/org/apache/kafka/server/share/session/ShareSessionCacheTest.java
 
b/server/src/test/java/org/apache/kafka/server/share/session/ShareSessionCacheTest.java
index baeb6ecbdf7..16da3016d29 100644
--- 
a/server/src/test/java/org/apache/kafka/server/share/session/ShareSessionCacheTest.java
+++ 
b/server/src/test/java/org/apache/kafka/server/share/session/ShareSessionCacheTest.java
@@ -35,6 +35,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class ShareSessionCacheTest {
 
@@ -177,6 +178,10 @@ public class ShareSessionCacheTest {
         // Remove session and verify listener are not called as connection 
disconnect listener didn't
         // remove the session.
         cache.remove(key1);
+        // Session should be marked stale only for memberId1.
+        assertTrue(cache.connectionSessionKeyAndState("conn-1").stale());
+        assertFalse(cache.connectionSessionKeyAndState("conn-2").stale());
+
         Mockito.verify(mockListener, Mockito.times(0)).onMemberLeave(groupId, 
memberId1);
         Mockito.verify(mockListener, Mockito.times(0)).onGroupEmpty(groupId);
         // Verify member count is updated
@@ -184,6 +189,8 @@ public class ShareSessionCacheTest {
 
         // Re-create session for memberId1.
         cache.maybeCreateSession(groupId, memberId1, 
mockedSharePartitionMap(1), "conn-1");
+        // Session should not be stale now.
+        assertFalse(cache.connectionSessionKeyAndState("conn-1").stale());
         assertEquals(2, cache.numMembers(groupId));
 
         // Simulate connection disconnect for memberId1.
@@ -235,6 +242,39 @@ public class ShareSessionCacheTest {
         assertEquals(1, cache.numMembers(groupId2));
     }
 
+    @Test
+    public void testShareGroupListenerEventsOnStaleSession() {
+        ShareGroupListener mockListener = 
Mockito.mock(ShareGroupListener.class);
+        ShareSessionCache cache = new ShareSessionCache(3);
+        cache.registerShareGroupListener(mockListener);
+
+        String groupId = "grp";
+        String memberId1 = Uuid.randomUuid().toString();
+        ShareSessionKey key1 = cache.maybeCreateSession(groupId, memberId1, 
mockedSharePartitionMap(1), "conn-1");
+
+        // Verify member count is tracked
+        assertEquals(1, cache.size());
+        assertNotNull(cache.get(key1));
+        assertEquals(1, cache.numMembers(groupId));
+
+        // Remove session and verify listener are not called as connection 
disconnect listener didn't
+        // remove the session.
+        cache.remove(key1);
+        // Session should be marked stale only for memberId1.
+        assertTrue(cache.connectionSessionKeyAndState("conn-1").stale());
+        Mockito.verify(mockListener, Mockito.times(0)).onMemberLeave(groupId, 
memberId1);
+        Mockito.verify(mockListener, Mockito.times(0)).onGroupEmpty(groupId);
+        // Verify member count is updated
+        assertEquals(0, cache.numMembers(groupId));
+
+        // Simulate connection disconnect for memberId1.
+        cache.connectionDisconnectListener().onDisconnect("conn-1");
+        // Verify only group empty event is triggered. Member leave event 
should not be triggered
+        // as session was already removed and marked stale.
+        Mockito.verify(mockListener, Mockito.times(0)).onMemberLeave(groupId, 
memberId1);
+        Mockito.verify(mockListener, Mockito.times(1)).onGroupEmpty(groupId);
+    }
+
     @Test
     public void testNoShareGroupListenerRegistered() {
         ShareSessionCache cache = new ShareSessionCache(3);

Reply via email to