This is an automated email from the ASF dual-hosted git repository.

chia7712 pushed a commit to branch 3.7
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.7 by this push:
     new 633d2f139c4  KAFKA-16305: Avoid optimisation in handshakeUnwrap 
(#15434)
633d2f139c4 is described below

commit 633d2f139c403cbbe2912d04f823d74c561dab76
Author: Gaurav Narula <gaurav_naru...@apple.com>
AuthorDate: Wed Feb 28 09:37:58 2024 +0000

     KAFKA-16305: Avoid optimisation in handshakeUnwrap (#15434)
    
    Performs additional unwrap during handshake after data from client is 
processed to support openssl, which needs the extra unwrap to complete 
handshake.
    
    Reviewers: Ismael Juma <ism...@juma.me.uk>, Rajini Sivaram 
<rajinisiva...@googlemail.com>
---
 .../kafka/common/network/SslTransportLayer.java    |  7 +--
 .../common/network/SslTransportLayerTest.java      | 55 ++++++++++++++++++++++
 2 files changed, 59 insertions(+), 3 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java 
b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
index 870b1c7f7dc..08b4b200200 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
@@ -506,13 +506,14 @@ public class SslTransportLayer implements TransportLayer {
     }
 
     /**
-     * Perform handshake unwrap
+     * Perform handshake unwrap.
+     * Visible for testing.
      * @param doRead boolean If true, read more from the socket channel
      * @param ignoreHandshakeStatus If true, continue to unwrap if data 
available regardless of handshake status
      * @return SSLEngineResult
      * @throws IOException
      */
-    private SSLEngineResult handshakeUnwrap(boolean doRead, boolean 
ignoreHandshakeStatus) throws IOException {
+    SSLEngineResult handshakeUnwrap(boolean doRead, boolean 
ignoreHandshakeStatus) throws IOException {
         log.trace("SSLHandshake handshakeUnwrap {}", channelId);
         SSLEngineResult result;
         int read = 0;
@@ -534,7 +535,7 @@ public class SslTransportLayer implements TransportLayer {
                     handshakeStatus == HandshakeStatus.NEED_UNWRAP) ||
                     (ignoreHandshakeStatus && netReadBuffer.position() != 
position);
             log.trace("SSLHandshake handshakeUnwrap: handshakeStatus {} status 
{}", handshakeStatus, result.getStatus());
-        } while (netReadBuffer.position() != 0 && cont);
+        } while (cont);
 
         // Throw EOF exception for failed read after processing already 
received data
         // so that handshake failures are reported correctly
diff --git 
a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
 
b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
index e65d53b2b7a..05ba32cd86f 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
@@ -53,6 +53,7 @@ import java.nio.ByteBuffer;
 import java.nio.channels.Channels;
 import java.nio.channels.SelectionKey;
 import java.nio.channels.SocketChannel;
+import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -67,14 +68,17 @@ import java.util.stream.Stream;
 
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLEngineResult;
 import javax.net.ssl.SSLException;
 import javax.net.ssl.SSLParameters;
+import javax.net.ssl.SSLSession;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assumptions.assumeTrue;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
@@ -1477,6 +1481,57 @@ public class SslTransportLayerTest {
         }
     }
 
+    /**
+     * SSLEngine implementations may transition from NEED_UNWRAP to NEED_UNWRAP
+     * even after reading all the data from the socket. This test ensures we
+     * continue unwrapping and not break early.
+     * Please refer <a 
href="https://issues.apache.org/jira/browse/KAFKA-16305">KAFKA-16305</a>
+     * for more information.
+     */
+    @Test
+    public void 
testHandshakeUnwrapContinuesUnwrappingOnNeedUnwrapAfterAllBytesRead() throws 
IOException {
+        // Given
+        byte[] data = "ClientHello?".getBytes(StandardCharsets.UTF_8);
+
+        SSLEngine sslEngine = mock(SSLEngine.class);
+        SocketChannel socketChannel = mock(SocketChannel.class);
+        SelectionKey selectionKey = mock(SelectionKey.class);
+        when(selectionKey.channel()).thenReturn(socketChannel);
+        SSLSession sslSession = mock(SSLSession.class);
+        SslTransportLayer sslTransportLayer = new SslTransportLayer(
+                "test-channel",
+                selectionKey,
+                sslEngine,
+                mock(ChannelMetadataRegistry.class)
+        );
+
+        when(sslEngine.getSession()).thenReturn(sslSession);
+        when(sslSession.getPacketBufferSize()).thenReturn(data.length * 2);
+        sslTransportLayer.startHandshake(); // to initialize the buffers
+
+        ByteBuffer netReadBuffer = sslTransportLayer.netReadBuffer();
+        netReadBuffer.clear();
+        ByteBuffer appReadBuffer = sslTransportLayer.appReadBuffer();
+        when(socketChannel.read(any(ByteBuffer.class))).then(invocation -> {
+            ((ByteBuffer) invocation.getArgument(0)).put(data);
+            return data.length;
+        });
+
+        when(sslEngine.unwrap(netReadBuffer, appReadBuffer))
+                .thenAnswer(invocation -> {
+                    netReadBuffer.flip();
+                    return new SSLEngineResult(SSLEngineResult.Status.OK, 
SSLEngineResult.HandshakeStatus.NEED_UNWRAP, data.length, 0);
+                }).thenReturn(new SSLEngineResult(SSLEngineResult.Status.OK, 
SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, 0));
+
+        // When
+        SSLEngineResult result = sslTransportLayer.handshakeUnwrap(true, 
false);
+
+        // Then
+        verify(sslEngine, times(2)).unwrap(netReadBuffer, appReadBuffer);
+        assertEquals(SSLEngineResult.Status.OK, result.getStatus());
+        assertEquals(SSLEngineResult.HandshakeStatus.NEED_WRAP, 
result.getHandshakeStatus());
+    }
+
     @Test
     public void testSSLEngineCloseInboundInvokedOnClose() throws IOException {
         // Given

Reply via email to