This is an automated email from the ASF dual-hosted git repository.

rsivaram pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 5d6936a4992  KAFKA-16305: Avoid optimisation in handshakeUnwrap 
(#15434)
5d6936a4992 is described below

commit 5d6936a4992b77ef68da216a7c2dbf1f8c9f909e
Author: Gaurav Narula <[email protected]>
AuthorDate: Wed Feb 28 09:37:58 2024 +0000

     KAFKA-16305: Avoid optimisation in handshakeUnwrap (#15434)
    
    Performs additional unwrap during handshake after data from client is 
processed to support openssl, which needs the extra unwrap to complete 
handshake.
    
    Reviewers: Ismael Juma <[email protected]>, Rajini Sivaram 
<[email protected]>
---
 .../kafka/common/network/SslTransportLayer.java    |  7 +--
 .../common/network/SslTransportLayerTest.java      | 60 ++++++++++++++++++++++
 2 files changed, 64 insertions(+), 3 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java 
b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
index 904c5216a40..da80e363a95 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
@@ -498,13 +498,14 @@ public class SslTransportLayer implements TransportLayer {
     }
 
     /**
-     * Perform handshake unwrap
+     * Perform handshake unwrap.
+     * Visible for testing.
      * @param doRead boolean If true, read more from the socket channel
      * @param ignoreHandshakeStatus If true, continue to unwrap if data 
available regardless of handshake status
      * @return SSLEngineResult
      * @throws IOException
      */
-    private SSLEngineResult handshakeUnwrap(boolean doRead, boolean 
ignoreHandshakeStatus) throws IOException {
+    SSLEngineResult handshakeUnwrap(boolean doRead, boolean 
ignoreHandshakeStatus) throws IOException {
         log.trace("SSLHandshake handshakeUnwrap {}", channelId);
         SSLEngineResult result;
         int read = 0;
@@ -526,7 +527,7 @@ public class SslTransportLayer implements TransportLayer {
                     handshakeStatus == HandshakeStatus.NEED_UNWRAP) ||
                     (ignoreHandshakeStatus && netReadBuffer.position() != 
position);
             log.trace("SSLHandshake handshakeUnwrap: handshakeStatus {} status 
{}", handshakeStatus, result.getStatus());
-        } while (netReadBuffer.position() != 0 && cont);
+        } while (cont);
 
         // Throw EOF exception for failed read after processing already 
received data
         // so that handshake failures are reported correctly
diff --git 
a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
 
b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
index d92f4facb3c..8b00bcdb955 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
@@ -36,6 +36,7 @@ import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.test.TestSslUtils;
 import org.apache.kafka.test.TestUtils;
 import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtensionContext;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.Arguments;
@@ -51,6 +52,7 @@ import java.nio.ByteBuffer;
 import java.nio.channels.Channels;
 import java.nio.channels.SelectionKey;
 import java.nio.channels.SocketChannel;
+import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -65,13 +67,20 @@ import java.util.stream.Stream;
 
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLEngineResult;
 import javax.net.ssl.SSLParameters;
+import javax.net.ssl.SSLSession;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assumptions.assumeTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 /**
  * Tests for the SSL transport layer. These use a test harness that runs a 
simple socket server that echos back responses.
@@ -1467,4 +1476,55 @@ public class SslTransportLayerTest {
             }
         }
     }
+
+    /**
+     * SSLEngine implementations may transition from NEED_UNWRAP to NEED_UNWRAP
+     * even after reading all the data from the socket. This test ensures we
+     * continue unwrapping and not break early.
+     * Please refer <a 
href="https://issues.apache.org/jira/browse/KAFKA-16305">KAFKA-16305</a>
+     * for more information.
+     */
+    @Test
+    public void 
testHandshakeUnwrapContinuesUnwrappingOnNeedUnwrapAfterAllBytesRead() throws 
IOException {
+        // Given
+        byte[] data = "ClientHello?".getBytes(StandardCharsets.UTF_8);
+
+        SSLEngine sslEngine = mock(SSLEngine.class);
+        SocketChannel socketChannel = mock(SocketChannel.class);
+        SelectionKey selectionKey = mock(SelectionKey.class);
+        when(selectionKey.channel()).thenReturn(socketChannel);
+        SSLSession sslSession = mock(SSLSession.class);
+        SslTransportLayer sslTransportLayer = new SslTransportLayer(
+                "test-channel",
+                selectionKey,
+                sslEngine,
+                mock(ChannelMetadataRegistry.class)
+        );
+
+        when(sslEngine.getSession()).thenReturn(sslSession);
+        when(sslSession.getPacketBufferSize()).thenReturn(data.length * 2);
+        sslTransportLayer.startHandshake(); // to initialize the buffers
+
+        ByteBuffer netReadBuffer = sslTransportLayer.netReadBuffer();
+        netReadBuffer.clear();
+        ByteBuffer appReadBuffer = sslTransportLayer.appReadBuffer();
+        when(socketChannel.read(any(ByteBuffer.class))).then(invocation -> {
+            ((ByteBuffer) invocation.getArgument(0)).put(data);
+            return data.length;
+        });
+
+        when(sslEngine.unwrap(netReadBuffer, appReadBuffer))
+                .thenAnswer(invocation -> {
+                    netReadBuffer.flip();
+                    return new SSLEngineResult(SSLEngineResult.Status.OK, 
SSLEngineResult.HandshakeStatus.NEED_UNWRAP, data.length, 0);
+                }).thenReturn(new SSLEngineResult(SSLEngineResult.Status.OK, 
SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, 0));
+
+        // When
+        SSLEngineResult result = sslTransportLayer.handshakeUnwrap(true, 
false);
+
+        // Then
+        verify(sslEngine, times(2)).unwrap(netReadBuffer, appReadBuffer);
+        assertEquals(SSLEngineResult.Status.OK, result.getStatus());
+        assertEquals(SSLEngineResult.HandshakeStatus.NEED_WRAP, 
result.getHandshakeStatus());
+    }
 }

Reply via email to