This is an automated email from the ASF dual-hosted git repository.

payang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new ea771563e0b KAFKA-14604: SASL session expiration time will be 
overflowed when calculation (#18526)
ea771563e0b is described below

commit ea771563e0b5f047027ab1ffa67d976d8df26864
Author: PoAn Yang <[email protected]>
AuthorDate: Sun Aug 3 19:12:04 2025 +0800

    KAFKA-14604: SASL session expiration time will be overflowed when 
calculation (#18526)
    
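    The session expiration time can overflow if a user configures a large
    expiration time (for example, a very large connections.max.reauth.ms or
    a long-lived token), because the lifetime in milliseconds is converted
    to nanoseconds and added to the authentication end time without any
    overflow check:
    
    ```
    sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 1000 * retvalSessionLifetimeMs;
    ```
    
    Fix it by performing the millisecond-to-nanosecond conversion with
    Math.multiplyExact (via the new Utils.msToNs helper) and the addition
    with Math.addExact, in both the client and server authenticators, so
    that an overflow throws an exception instead of silently wrapping around.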
    
    Reviewers: TaiJuWu <[email protected]>, Luke Chen <[email protected]>,
     TengYao Chi <[email protected]>
    
    Signed-off-by: PoAn Yang <[email protected]>
---
 .../authenticator/SaslClientAuthenticator.java     |  2 +-
 .../authenticator/SaslServerAuthenticator.java     |  2 +-
 .../java/org/apache/kafka/common/utils/Utils.java  | 13 +++
 .../authenticator/SaslAuthenticatorTest.java       | 97 +++++++++++++++++++---
 .../authenticator/SaslServerAuthenticatorTest.java | 29 +++++++
 .../org/apache/kafka/common/utils/UtilsTest.java   |  7 ++
 6 files changed, 136 insertions(+), 14 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
 
b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
index addacd92722..25653636b40 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
@@ -690,7 +690,7 @@ public class SaslClientAuthenticator implements 
Authenticator {
                 double pctToUse = 
pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + RNG.nextDouble()
                         * 
pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously;
                 sessionLifetimeMsToUse = (long) (positiveSessionLifetimeMs * 
pctToUse);
-                clientSessionReauthenticationTimeNanos = 
authenticationEndNanos + 1000 * 1000 * sessionLifetimeMsToUse;
+                clientSessionReauthenticationTimeNanos = 
Math.addExact(authenticationEndNanos, Utils.msToNs(sessionLifetimeMsToUse));
                 log.debug(
                         "Finished {} with session expiration in {} ms and 
session re-authentication on or after {} ms",
                         authenticationOrReauthenticationText(), 
positiveSessionLifetimeMs, sessionLifetimeMsToUse);
diff --git 
a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
 
b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
index a0dbe5b21dc..b84b5dc2abc 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
@@ -681,7 +681,7 @@ public class SaslServerAuthenticator implements 
Authenticator {
                 else
                     retvalSessionLifetimeMs = 
zeroIfNegative(Math.min(credentialExpirationMs - authenticationEndMs, 
connectionsMaxReauthMs));
 
-                sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 
1000 * retvalSessionLifetimeMs;
+                sessionExpirationTimeNanos = 
Math.addExact(authenticationEndNanos, Utils.msToNs(retvalSessionLifetimeMs));
             }
 
             if (credentialExpirationMs != null) {
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java 
b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
index 7e8b990542d..f60bc03bff9 100644
--- a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
@@ -1719,4 +1719,17 @@ public final class Utils {
     public interface ThrowingRunnable {
         void run() throws Exception;
     }
+
+    /**
+     * convert millisecond to nanosecond, or throw exception if overflow
+     * @param timeMs the time in millisecond
+     * @return the converted nanosecond
+     */
+    public static long msToNs(long timeMs) {
+        try {
+            return Math.multiplyExact(1000 * 1000, timeMs);
+        } catch (ArithmeticException e) {
+            throw new IllegalArgumentException("Cannot convert " + timeMs + " 
millisecond to nanosecond due to arithmetic overflow", e);
+        }
+    }
 }
diff --git 
a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java
 
b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java
index 0a63c05eb35..13ffba2715d 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java
@@ -158,6 +158,7 @@ public class SaslAuthenticatorTest {
     private static final long CONNECTIONS_MAX_REAUTH_MS_VALUE = 100L;
     private static final int BUFFER_SIZE = 4 * 1024;
     private static Time time = Time.SYSTEM;
+    private static boolean needLargeExpiration = false;
 
     private NioEchoServer server;
     private Selector selector;
@@ -181,6 +182,7 @@ public class SaslAuthenticatorTest {
 
     @AfterEach
     public void teardown() throws Exception {
+        needLargeExpiration = false;
         if (server != null)
             this.server.close();
         if (selector != null)
@@ -1610,6 +1612,42 @@ public class SaslAuthenticatorTest {
         server.verifyReauthenticationMetrics(0, 1);
     }
 
+    @Test
+    public void testReauthenticateWithLargeReauthValue() throws Exception {
+        // enable it, we'll get a large expiration timestamp token
+        needLargeExpiration = true;
+        String node = "0";
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
+
+        configureMechanisms(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
+            List.of(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM));
+        // set a large re-auth timeout in server side
+        
saslServerConfigs.put(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS_CONFIG, 
Long.MAX_VALUE);
+        server = createEchoServer(securityProtocol);
+
+        // set to default value for sasl login configs for initialization in 
ExpiringCredentialRefreshConfig
+        saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR, 
SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_FACTOR);
+        saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER, 
SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_JITTER);
+        
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS, 
SaslConfigs.DEFAULT_LOGIN_REFRESH_MIN_PERIOD_SECONDS);
+        saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS, 
SaslConfigs.DEFAULT_LOGIN_REFRESH_BUFFER_SECONDS);
+        saslClientConfigs.put(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, 
AlternateLoginCallbackHandler.class);
+
+        createCustomClientConnection(securityProtocol, 
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, node, true);
+
+        // channel should be not null before sasl handshake
+        assertNotNull(selector.channel(node));
+
+        TestUtils.waitForCondition(() -> {
+            selector.poll(1000);
+            // this channel should be closed due to session timeout 
calculation overflow
+            return selector.channel(node) == null;
+        }, "channel didn't close with large re-authentication value");
+
+        // ensure metrics are as expected
+        server.verifyAuthenticationMetrics(0, 0);
+        server.verifyReauthenticationMetrics(0, 0);
+    }
+
     @Test
     public void testCorrelationId() {
         SaslClientAuthenticator authenticator = new SaslClientAuthenticator(
@@ -2002,7 +2040,7 @@ public class SaslAuthenticatorTest {
         if (enableSaslAuthenticateHeader)
             createClientConnection(securityProtocol, node);
         else
-            
createClientConnectionWithoutSaslAuthenticateHeader(securityProtocol, 
saslMechanism, node);
+            createCustomClientConnection(securityProtocol, saslMechanism, 
node, false);
     }
 
     private NioEchoServer startServerApiVersionsUnsupportedByClient(final 
SecurityProtocol securityProtocol, String saslMechanism) throws Exception {
@@ -2090,15 +2128,13 @@ public class SaslAuthenticatorTest {
         return server;
     }
 
-    private void createClientConnectionWithoutSaslAuthenticateHeader(final 
SecurityProtocol securityProtocol,
-            final String saslMechanism, String node) throws Exception {
-
-        final ListenerName listenerName = 
ListenerName.forSecurityProtocol(securityProtocol);
-        final Map<String, ?> configs = Collections.emptyMap();
-        final JaasContext jaasContext = JaasContext.loadClientContext(configs);
-        final Map<String, JaasContext> jaasContexts = 
Collections.singletonMap(saslMechanism, jaasContext);
-
-        SaslChannelBuilder clientChannelBuilder = new 
SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts,
+    private SaslChannelBuilder saslChannelBuilderWithoutHeader(
+        final SecurityProtocol securityProtocol,
+        final String saslMechanism,
+        final Map<String, JaasContext> jaasContexts,
+        final ListenerName listenerName
+    ) {
+        return new SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts,
                 securityProtocol, listenerName, false, saslMechanism,
                 null, null, null, time, new LogContext(), null) {
 
@@ -2125,6 +2161,42 @@ public class SaslAuthenticatorTest {
                 };
             }
         };
+    }
+
+    private void createCustomClientConnection(
+        final SecurityProtocol securityProtocol,
+        final String saslMechanism,
+        String node,
+        boolean withSaslAuthenticateHeader
+    ) throws Exception {
+
+        final ListenerName listenerName = 
ListenerName.forSecurityProtocol(securityProtocol);
+        final Map<String, ?> configs = Collections.emptyMap();
+        final JaasContext jaasContext = JaasContext.loadClientContext(configs);
+        final Map<String, JaasContext> jaasContexts = 
Collections.singletonMap(saslMechanism, jaasContext);
+
+        SaslChannelBuilder clientChannelBuilder;
+        if (!withSaslAuthenticateHeader) {
+            clientChannelBuilder = 
saslChannelBuilderWithoutHeader(securityProtocol, saslMechanism, jaasContexts, 
listenerName);
+        } else {
+            clientChannelBuilder = new 
SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts,
+                securityProtocol, listenerName, false, saslMechanism,
+                null, null, null, time, new LogContext(), null) {
+
+                @Override
+                protected SaslClientAuthenticator 
buildClientAuthenticator(Map<String, ?> configs,
+                                                                           
AuthenticateCallbackHandler callbackHandler,
+                                                                           
String id,
+                                                                           
String serverHost,
+                                                                           
String servicePrincipal,
+                                                                           
TransportLayer transportLayer,
+                                                                           
Subject subject) {
+
+                    return new SaslClientAuthenticator(configs, 
callbackHandler, id, subject,
+                        servicePrincipal, serverHost, saslMechanism, 
transportLayer, time, new LogContext());
+                }
+            };
+        }
         clientChannelBuilder.configure(saslClientConfigs);
         this.selector = NetworkTestUtils.createSelector(clientChannelBuilder, 
time);
         InetSocketAddress addr = new InetSocketAddress("localhost", 
server.port());
@@ -2581,10 +2653,11 @@ public class SaslAuthenticatorTest {
                                 + ++numInvocations;
                         String headerJson = "{" + claimOrHeaderJsonText("alg", 
"none") + "}";
                         /*
-                         * Use a short lifetime so the background refresh 
thread replaces it before we
+                         * If we're testing large expiration scenario, use a 
large lifetime.
+                         * Otherwise, use a short lifetime so the background 
refresh thread replaces it before we
                          * re-authenticate
                          */
-                        String lifetimeSecondsValueToUse = "1";
+                        String lifetimeSecondsValueToUse = needLargeExpiration 
? String.valueOf(Long.MAX_VALUE) : "1";
                         String claimsJson;
                         try {
                             claimsJson = String.format("{%s,%s,%s}",
diff --git 
a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
 
b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
index fd15fe6ff67..76c9109bd8c 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
@@ -270,6 +270,35 @@ public class SaslServerAuthenticatorTest {
         }
     }
 
+    @Test
+    public void testSessionWontExpireWithLargeExpirationTime() throws 
IOException {
+        String mechanism = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM;
+        SaslServer saslServer = mock(SaslServer.class);
+        MockTime time = new MockTime(0, 1, 1000);
+        // set a Long.MAX_VALUE as the expiration time
+        Duration largeExpirationTime = Duration.ofMillis(Long.MAX_VALUE);
+
+        try (
+            MockedStatic<?> ignored = mockSaslServer(saslServer, mechanism, 
time, largeExpirationTime);
+            MockedStatic<?> ignored2 = mockKafkaPrincipal("[principal-type]", 
"[principal-name");
+            TransportLayer transportLayer = mockTransportLayer()
+        ) {
+
+            SaslServerAuthenticator authenticator = 
getSaslServerAuthenticatorForOAuth(mechanism, transportLayer, time, 
largeExpirationTime.toMillis());
+
+            mockRequest(saslHandshakeRequest(mechanism), transportLayer);
+            authenticator.authenticate();
+
+            when(saslServer.isComplete()).thenReturn(false).thenReturn(true);
+            mockRequest(saslAuthenticateRequest(), transportLayer);
+
+            Throwable t = assertThrows(IllegalArgumentException.class, () -> 
authenticator.authenticate());
+            assertEquals(ArithmeticException.class, t.getCause().getClass());
+            assertEquals("Cannot convert " + Long.MAX_VALUE + " millisecond to 
nanosecond due to arithmetic overflow",
+                t.getMessage());
+        }
+    }
+
     private SaslServerAuthenticator getSaslServerAuthenticatorForOAuth(String 
mechanism, TransportLayer transportLayer, Time time, Long maxReauth) {
         Map<String, ?> configs = 
Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG,
                 Collections.singletonList(mechanism));
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java 
b/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java
index bd8d5f73888..74518fe0f44 100755
--- a/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java
@@ -1269,6 +1269,13 @@ public class UtilsTest {
         assertEquals(expected, recorded);
     }
 
+    @Test
+    public void testMsToNs() {
+        assertEquals(1000000, Utils.msToNs(1));
+        assertEquals(0, Utils.msToNs(0));
+        assertThrows(IllegalArgumentException.class, () -> 
Utils.msToNs(Long.MAX_VALUE));
+    }
+
     private Callable<Void> recordingCallable(Map<String, Object> recordingMap, 
String success, TestException failure) {
         return () -> {
             if (success == null)

Reply via email to