This is an automated email from the ASF dual-hosted git repository.
rsivaram pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/trunk by this push:
new 5d6936a4992 KAFKA-16305: Avoid optimisation in handshakeUnwrap (#15434)
5d6936a4992 is described below
commit 5d6936a4992b77ef68da216a7c2dbf1f8c9f909e
Author: Gaurav Narula <[email protected]>
AuthorDate: Wed Feb 28 09:37:58 2024 +0000
KAFKA-16305: Avoid optimisation in handshakeUnwrap (#15434)
Performs an additional unwrap during the handshake, after data from the
client has been processed, to support OpenSSL, which needs the extra
unwrap to complete the handshake.
Reviewers: Ismael Juma <[email protected]>, Rajini Sivaram
<[email protected]>
---
.../kafka/common/network/SslTransportLayer.java | 7 +--
.../common/network/SslTransportLayerTest.java | 60 ++++++++++++++++++++++
2 files changed, 64 insertions(+), 3 deletions(-)
diff --git
a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
index 904c5216a40..da80e363a95 100644
---
a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
+++
b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
@@ -498,13 +498,14 @@ public class SslTransportLayer implements TransportLayer {
}
/**
- * Perform handshake unwrap
+ * Perform handshake unwrap.
+ * Visible for testing.
* @param doRead boolean If true, read more from the socket channel
* @param ignoreHandshakeStatus If true, continue to unwrap if data
available regardless of handshake status
* @return SSLEngineResult
* @throws IOException
*/
- private SSLEngineResult handshakeUnwrap(boolean doRead, boolean
ignoreHandshakeStatus) throws IOException {
+ SSLEngineResult handshakeUnwrap(boolean doRead, boolean
ignoreHandshakeStatus) throws IOException {
log.trace("SSLHandshake handshakeUnwrap {}", channelId);
SSLEngineResult result;
int read = 0;
@@ -526,7 +527,7 @@ public class SslTransportLayer implements TransportLayer {
handshakeStatus == HandshakeStatus.NEED_UNWRAP) ||
(ignoreHandshakeStatus && netReadBuffer.position() !=
position);
log.trace("SSLHandshake handshakeUnwrap: handshakeStatus {} status
{}", handshakeStatus, result.getStatus());
- } while (netReadBuffer.position() != 0 && cont);
+ } while (cont);
// Throw EOF exception for failed read after processing already
received data
// so that handshake failures are reported correctly
diff --git
a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
index d92f4facb3c..8b00bcdb955 100644
---
a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
+++
b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
@@ -36,6 +36,7 @@ import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.test.TestSslUtils;
import org.apache.kafka.test.TestUtils;
import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
@@ -51,6 +52,7 @@ import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
+import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -65,13 +67,20 @@ import java.util.stream.Stream;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLParameters;
+import javax.net.ssl.SSLSession;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
/**
* Tests for the SSL transport layer. These use a test harness that runs a
simple socket server that echos back responses.
@@ -1467,4 +1476,55 @@ public class SslTransportLayerTest {
}
}
}
+
+ /**
+ * Some SSLEngine implementations (e.g. OpenSSL-based ones) may remain in NEED_UNWRAP
+ * even after all data read from the socket has been consumed. This test ensures
+ * handshakeUnwrap keeps unwrapping instead of breaking out of the loop early.
+ * Please refer to <a
href="https://issues.apache.org/jira/browse/KAFKA-16305">KAFKA-16305</a>
+ * for more information.
+ */
+ @Test
+ public void
testHandshakeUnwrapContinuesUnwrappingOnNeedUnwrapAfterAllBytesRead() throws
IOException {
+ // Given
+ byte[] data = "ClientHello?".getBytes(StandardCharsets.UTF_8);
+
+ SSLEngine sslEngine = mock(SSLEngine.class);
+ SocketChannel socketChannel = mock(SocketChannel.class);
+ SelectionKey selectionKey = mock(SelectionKey.class);
+ when(selectionKey.channel()).thenReturn(socketChannel);
+ SSLSession sslSession = mock(SSLSession.class);
+ SslTransportLayer sslTransportLayer = new SslTransportLayer(
+ "test-channel",
+ selectionKey,
+ sslEngine,
+ mock(ChannelMetadataRegistry.class)
+ );
+
+ when(sslEngine.getSession()).thenReturn(sslSession);
+ when(sslSession.getPacketBufferSize()).thenReturn(data.length * 2); // presumably sizes the net buffers with room for the payload — confirm
+ sslTransportLayer.startHandshake(); // to initialize the buffers
+
+ ByteBuffer netReadBuffer = sslTransportLayer.netReadBuffer();
+ netReadBuffer.clear(); // start from an empty read buffer
+ ByteBuffer appReadBuffer = sslTransportLayer.appReadBuffer();
+ when(socketChannel.read(any(ByteBuffer.class))).then(invocation -> { // every socket read delivers the whole payload
+ ((ByteBuffer) invocation.getArgument(0)).put(data);
+ return data.length;
+ });
+
+ when(sslEngine.unwrap(netReadBuffer, appReadBuffer))
+ .thenAnswer(invocation -> {
+ netReadBuffer.flip(); // presumably leaves no remaining bytes, as if the engine consumed all input — confirm against handshakeUnwrap
+ return new SSLEngineResult(SSLEngineResult.Status.OK,
SSLEngineResult.HandshakeStatus.NEED_UNWRAP, data.length, 0); // 1st unwrap: all bytes consumed, yet still NEED_UNWRAP
+ }).thenReturn(new SSLEngineResult(SSLEngineResult.Status.OK,
SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, 0)); // 2nd unwrap: handshake moves on to NEED_WRAP
+
+ // When
+ SSLEngineResult result = sslTransportLayer.handshakeUnwrap(true,
false); // doRead=true, ignoreHandshakeStatus=false
+
+ // Then
+ verify(sslEngine, times(2)).unwrap(netReadBuffer, appReadBuffer); // a second unwrap must follow the NEED_UNWRAP result
+ assertEquals(SSLEngineResult.Status.OK, result.getStatus());
+ assertEquals(SSLEngineResult.HandshakeStatus.NEED_WRAP,
result.getHandshakeStatus());
+ }
}