Repository: kafka Updated Branches: refs/heads/trunk e554dc518 -> d60f011d7
KAFKA-5920; Handle SSL handshake failures as authentication exceptions 1. Propagate `SSLException` as `SslAuthenticationException` to enable clients to report these and avoid retries 2. Updates to `SslTransportLayer` to process bytes received even if end-of-stream 3. Some tidy up of authentication handling 4. Report exceptions in SaslClientAuthenticator as AuthenticationExceptions Author: Rajini Sivaram <rajinisiva...@googlemail.com> Reviewers: Ismael Juma <ism...@juma.me.uk> Closes #3918 from rajinisivaram/KAFKA-5920-SSL-handshake-failure Project: http://git-wip-us.apache.org/repos/asf/kafka/repo Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/d60f011d Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/d60f011d Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/d60f011d Branch: refs/heads/trunk Commit: d60f011d77ce80a44b02d43bf0889a50a8797dcd Parents: e554dc5 Author: Rajini Sivaram <rajinisiva...@googlemail.com> Authored: Fri Sep 22 20:26:46 2017 +0100 Committer: Ismael Juma <ism...@juma.me.uk> Committed: Fri Sep 22 20:29:25 2017 +0100 ---------------------------------------------------------------------- .../common/errors/AuthenticationException.java | 3 + .../errors/SslAuthenticationException.java | 44 +++ .../kafka/common/network/Authenticator.java | 17 +- .../kafka/common/network/KafkaChannel.java | 31 +- .../common/network/PlaintextChannelBuilder.java | 7 - .../apache/kafka/common/network/Selector.java | 3 + .../kafka/common/network/SslChannelBuilder.java | 8 - .../kafka/common/network/SslTransportLayer.java | 296 +++++++++++------- .../kafka/common/network/TransportLayer.java | 12 +- .../authenticator/SaslClientAuthenticator.java | 31 +- .../authenticator/SaslServerAuthenticator.java | 16 +- .../kafka/common/network/NioEchoServer.java | 12 +- .../common/network/SslTransportLayerTest.java | 305 +++++++++++++------ 13 files changed, 506 insertions(+), 279 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java ---------------------------------------------------------------------- diff --git a/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java b/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java index c56ac88..f6458c6 100644 --- a/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java +++ b/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java @@ -16,6 +16,8 @@ */ package org.apache.kafka.common.errors; +import javax.net.ssl.SSLException; + /** * This exception indicates that SASL authentication has failed. * On authentication failure, clients abort the operation requested and raise one @@ -27,6 +29,7 @@ package org.apache.kafka.common.errors; * is not supported on the broker.</li> * <li>{@link IllegalSaslStateException} if an unexpected request is received on during SASL * handshake. This could be due to misconfigured security protocol.</li> + * <li>{@link SslAuthenticationException} if SSL handshake failed due to any {@link SSLException}. * </ul> */ public class AuthenticationException extends ApiException { http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java ---------------------------------------------------------------------- diff --git a/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java b/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java new file mode 100644 index 0000000..3cdbf2a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +import javax.net.ssl.SSLException; + +/** + * This exception indicates that SSL handshake has failed. See {@link #getCause()} + * for the {@link SSLException} that caused this failure. + * <p> + * SSL handshake failures in clients may indicate client authentication + * failure due to untrusted certificates if server is configured to request + * client certificates. Handshake failures could also indicate misconfigured + * security including protocol/cipher suite mismatch, server certificate + * authentication failure or server host name verification failure. 
+ * </p> + */ +public class SslAuthenticationException extends AuthenticationException { + + private static final long serialVersionUID = 1L; + + public SslAuthenticationException(String message) { + super(message); + } + + public SslAuthenticationException(String message, Throwable cause) { + super(message, cause); + } + +} http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java ---------------------------------------------------------------------- diff --git a/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java b/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java index fa1123e..4e2e727 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java @@ -16,7 +16,7 @@ */ package org.apache.kafka.common.network; -import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.errors.AuthenticationException; import org.apache.kafka.common.security.auth.KafkaPrincipal; import java.io.Closeable; @@ -28,15 +28,14 @@ import java.io.IOException; public interface Authenticator extends Closeable { /** * Implements any authentication mechanism. Use transportLayer to read or write tokens. - * If no further authentication needs to be done returns. + * For security protocols PLAINTEXT and SSL, this is a no-op since no further authentication + * needs to be done. For SASL_PLAINTEXT and SASL_SSL, this performs the SASL authentication. 
+ * + * @throws AuthenticationException if authentication fails due to invalid credentials or + * other security configuration errors + * @throws IOException if read/write fails due to an I/O error */ - void authenticate() throws IOException; - - /** - * Returns the first error encountered during authentication - * @return authentication error if authentication failed, Errors.NONE otherwise - */ - Errors error(); + void authenticate() throws AuthenticationException, IOException; /** * Returns Principal using PrincipalBuilder http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java ---------------------------------------------------------------------- diff --git a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java index 24cd9cf..f07035a 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java +++ b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java @@ -69,26 +69,21 @@ public class KafkaChannel { } /** - * Does handshake of transportLayer and authentication using configured authenticator + * Does handshake of transportLayer and authentication using configured authenticator. + * For SSL with client authentication enabled, {@link TransportLayer#handshake()} performs + * authentication. For SASL, authentication is performed by {@link Authenticator#authenticate()}. 
*/ - public void prepare() throws IOException { - if (!transportLayer.ready()) - transportLayer.handshake(); - if (transportLayer.ready() && !authenticator.complete()) { - try { + public void prepare() throws AuthenticationException, IOException { + try { + if (!transportLayer.ready()) + transportLayer.handshake(); + if (transportLayer.ready() && !authenticator.complete()) authenticator.authenticate(); - } catch (AuthenticationException e) { - switch (authenticator.error()) { - case SASL_AUTHENTICATION_FAILED: - case ILLEGAL_SASL_STATE: - case UNSUPPORTED_SASL_MECHANISM: - state = new ChannelState(ChannelState.State.AUTHENTICATION_FAILED, e); - break; - default: - // Other errors are handled as network exceptions in Selector - } - throw e; - } + } catch (AuthenticationException e) { + // Clients are notified of authentication exceptions to enable operations to be terminated + // without retries. Other errors are handled as network exceptions in Selector. + state = new ChannelState(ChannelState.State.AUTHENTICATION_FAILED, e); + throw e; } if (ready()) state = ChannelState.READY; http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java ---------------------------------------------------------------------- diff --git a/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java index 95fd903..c0d1059 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java +++ b/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java @@ -18,7 +18,6 @@ package org.apache.kafka.common.network; import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.memory.MemoryPool; -import org.apache.kafka.common.protocol.Errors; import org.apache.kafka.common.security.auth.KafkaPrincipal; import 
org.apache.kafka.common.security.auth.KafkaPrincipalBuilder; import org.apache.kafka.common.security.auth.PlaintextAuthenticationContext; @@ -80,12 +79,6 @@ public class PlaintextChannelBuilder implements ChannelBuilder { } @Override - public Errors error() { - // PLAINTEXT never fails authentication - return Errors.NONE; - } - - @Override public void close() { if (principalBuilder instanceof Closeable) Utils.closeQuietly((Closeable) principalBuilder, "principal builder"); http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/Selector.java ---------------------------------------------------------------------- diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java index 7977879..b753745 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java @@ -42,6 +42,7 @@ import org.apache.kafka.common.memory.MemoryPool; import org.apache.kafka.common.metrics.Measurable; import org.apache.kafka.common.metrics.MetricConfig; import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.errors.AuthenticationException; import org.apache.kafka.common.metrics.Metrics; import org.apache.kafka.common.metrics.Sensor; import org.apache.kafka.common.metrics.stats.Avg; @@ -485,6 +486,8 @@ public class Selector implements Selectable, AutoCloseable { String desc = channel.socketDescription(); if (e instanceof IOException) log.debug("Connection with {} disconnected", desc, e); + else if (e instanceof AuthenticationException) // will be logged later as error by clients + log.debug("Connection with {} disconnected due to authentication exception", desc, e); else log.warn("Unexpected error from {}; closing connection", desc, e); close(channel, true); 
http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java ---------------------------------------------------------------------- diff --git a/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java index 80b9e9a..9519e58 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java +++ b/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java @@ -18,7 +18,6 @@ package org.apache.kafka.common.network; import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.memory.MemoryPool; -import org.apache.kafka.common.protocol.Errors; import org.apache.kafka.common.security.auth.KafkaPrincipal; import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder; import org.apache.kafka.common.security.auth.SslAuthenticationContext; @@ -158,12 +157,5 @@ public class SslChannelBuilder implements ChannelBuilder { public boolean complete() { return true; } - - @Override - public Errors error() { - // SSL authentication failures are currently not propagated to clients - return Errors.NONE; - } - } } http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java ---------------------------------------------------------------------- diff --git a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java index 3cd0114..f5e1e70 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java +++ b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java @@ -34,6 +34,7 @@ import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLPeerUnverifiedException; +import 
org.apache.kafka.common.errors.SslAuthenticationException; import org.apache.kafka.common.security.auth.KafkaPrincipal; import org.apache.kafka.common.utils.Utils; import org.slf4j.Logger; @@ -44,6 +45,14 @@ import org.slf4j.LoggerFactory; */ public class SslTransportLayer implements TransportLayer { private static final Logger log = LoggerFactory.getLogger(SslTransportLayer.class); + + private enum State { + HANDSHAKE, + HANDSHAKE_FAILED, + READY, + CLOSING + } + private final String channelId; private final SSLEngine sslEngine; private final SelectionKey key; @@ -52,8 +61,8 @@ public class SslTransportLayer implements TransportLayer { private HandshakeStatus handshakeStatus; private SSLEngineResult handshakeResult; - private boolean handshakeComplete = false; - private boolean closing = false; + private State state; + private SslAuthenticationException handshakeException; private ByteBuffer netReadBuffer; private ByteBuffer netWriteBuffer; private ByteBuffer appReadBuffer; @@ -89,8 +98,7 @@ public class SslTransportLayer implements TransportLayer { netWriteBuffer.limit(0); netReadBuffer.position(0); netReadBuffer.limit(0); - handshakeComplete = false; - closing = false; + state = State.HANDSHAKE; //initiate handshake sslEngine.beginHandshake(); handshakeStatus = sslEngine.getHandshakeStatus(); @@ -98,7 +106,7 @@ public class SslTransportLayer implements TransportLayer { @Override public boolean ready() { - return handshakeComplete; + return state == State.READY; } /** @@ -141,8 +149,8 @@ public class SslTransportLayer implements TransportLayer { */ @Override public void close() throws IOException { - if (closing) return; - closing = true; + if (state == State.CLOSING) return; + state = State.CLOSING; sslEngine.closeOutbound(); try { if (isConnected()) { @@ -183,12 +191,22 @@ public class SslTransportLayer implements TransportLayer { } /** - * Flushes the buffer to the network, non blocking + * Reads available bytes from socket channel to `netReadBuffer`. 
+ * Visible for testing. + * @return number of bytes read + */ + protected int readFromSocketChannel() throws IOException { + return socketChannel.read(netReadBuffer); + } + + /** + * Flushes the buffer to the network, non blocking. + * Visible for testing. * @param buf ByteBuffer * @return boolean true if the buffer has been emptied out, false otherwise * @throws IOException */ - private boolean flush(ByteBuffer buf) throws IOException { + protected boolean flush(ByteBuffer buf) throws IOException { int remaining = buf.remaining(); if (remaining > 0) { int written = socketChannel.write(buf); @@ -217,101 +235,137 @@ public class SslTransportLayer implements TransportLayer { * | unwrap() | Finished | FINISHED | * +-------------+----------------------------------+-------------+ * - * @throws IOException + * @throws IOException if read/write fails + * @throws SslAuthenticationException if handshake fails with an {@link SSLException} */ @Override public void handshake() throws IOException { + // Reset state to support renegotiation. This can be removed if renegotiation support is removed. + if (state == State.READY) + state = State.HANDSHAKE; + + int read = 0; + try { + // Read any available bytes before attempting any writes to ensure that handshake failures + // reported by the peer are processed even if writes fail (since peer closes connection + // if handshake fails) + if (key.isReadable()) + read = readFromSocketChannel(); + + doHandshake(); + } catch (SSLException e) { + handshakeFailure(e, true); + } catch (IOException e) { + maybeThrowSslAuthenticationException(); + + // this exception could be due to a write. 
If there is data available to unwrap, + // process the data so that any SSLExceptions are reported + if (handshakeStatus == HandshakeStatus.NEED_UNWRAP && netReadBuffer.position() > 0) { + try { + handshakeUnwrap(false); + } catch (SSLException e1) { + handshakeFailure(e1, false); + } + } + // If we get here, this is not a handshake failure, throw the original IOException + throw e; + } + + // Read from socket failed, so throw any pending handshake exception or EOF exception. + if (read == -1) { + maybeThrowSslAuthenticationException(); + throw new EOFException("EOF during handshake, handshake status is " + handshakeStatus); + } + } + + private void doHandshake() throws IOException { boolean read = key.isReadable(); boolean write = key.isWritable(); - handshakeComplete = false; handshakeStatus = sslEngine.getHandshakeStatus(); if (!flush(netWriteBuffer)) { key.interestOps(key.interestOps() | SelectionKey.OP_WRITE); return; } - try { - switch (handshakeStatus) { - case NEED_TASK: - log.trace("SSLHandshake NEED_TASK channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", - channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); - handshakeStatus = runDelegatedTasks(); + // Throw any pending handshake exception since `netWriteBuffer` has been flushed + maybeThrowSslAuthenticationException(); + + switch (handshakeStatus) { + case NEED_TASK: + log.trace("SSLHandshake NEED_TASK channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", + channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); + handshakeStatus = runDelegatedTasks(); + break; + case NEED_WRAP: + log.trace("SSLHandshake NEED_WRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", + channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); + handshakeResult = handshakeWrap(write); + if (handshakeResult.getStatus() == 
Status.BUFFER_OVERFLOW) { + int currentNetWriteBufferSize = netWriteBufferSize(); + netWriteBuffer.compact(); + netWriteBuffer = Utils.ensureCapacity(netWriteBuffer, currentNetWriteBufferSize); + netWriteBuffer.flip(); + if (netWriteBuffer.limit() >= currentNetWriteBufferSize) { + throw new IllegalStateException("Buffer overflow when available data size (" + netWriteBuffer.limit() + + ") >= network buffer size (" + currentNetWriteBufferSize + ")"); + } + } else if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) { + throw new IllegalStateException("Should not have received BUFFER_UNDERFLOW during handshake WRAP."); + } else if (handshakeResult.getStatus() == Status.CLOSED) { + throw new EOFException(); + } + log.trace("SSLHandshake NEED_WRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", + channelId, handshakeResult, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); + //if handshake status is not NEED_UNWRAP or unable to flush netWriteBuffer contents + //we will break here otherwise we can do need_unwrap in the same call. 
+ if (handshakeStatus != HandshakeStatus.NEED_UNWRAP || !flush(netWriteBuffer)) { + key.interestOps(key.interestOps() | SelectionKey.OP_WRITE); break; - case NEED_WRAP: - log.trace("SSLHandshake NEED_WRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", - channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); - handshakeResult = handshakeWrap(write); + } + case NEED_UNWRAP: + log.trace("SSLHandshake NEED_UNWRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", + channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); + do { + handshakeResult = handshakeUnwrap(read); if (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW) { - int currentNetWriteBufferSize = netWriteBufferSize(); - netWriteBuffer.compact(); - netWriteBuffer = Utils.ensureCapacity(netWriteBuffer, currentNetWriteBufferSize); - netWriteBuffer.flip(); - if (netWriteBuffer.limit() >= currentNetWriteBufferSize) { - throw new IllegalStateException("Buffer overflow when available data size (" + netWriteBuffer.limit() + - ") >= network buffer size (" + currentNetWriteBufferSize + ")"); + int currentAppBufferSize = applicationBufferSize(); + appReadBuffer = Utils.ensureCapacity(appReadBuffer, currentAppBufferSize); + if (appReadBuffer.position() > currentAppBufferSize) { + throw new IllegalStateException("Buffer underflow when available data size (" + appReadBuffer.position() + + ") > packet buffer size (" + currentAppBufferSize + ")"); } - } else if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) { - throw new IllegalStateException("Should not have received BUFFER_UNDERFLOW during handshake WRAP."); - } else if (handshakeResult.getStatus() == Status.CLOSED) { - throw new EOFException(); } - log.trace("SSLHandshake NEED_WRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", - channelId, handshakeResult, 
appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); - //if handshake status is not NEED_UNWRAP or unable to flush netWriteBuffer contents - //we will break here otherwise we can do need_unwrap in the same call. - if (handshakeStatus != HandshakeStatus.NEED_UNWRAP || !flush(netWriteBuffer)) { - key.interestOps(key.interestOps() | SelectionKey.OP_WRITE); - break; - } - case NEED_UNWRAP: - log.trace("SSLHandshake NEED_UNWRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", - channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); - do { - handshakeResult = handshakeUnwrap(read); - if (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW) { - int currentAppBufferSize = applicationBufferSize(); - appReadBuffer = Utils.ensureCapacity(appReadBuffer, currentAppBufferSize); - if (appReadBuffer.position() > currentAppBufferSize) { - throw new IllegalStateException("Buffer underflow when available data size (" + appReadBuffer.position() + - ") > packet buffer size (" + currentAppBufferSize + ")"); - } - } - } while (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW); - if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) { - int currentNetReadBufferSize = netReadBufferSize(); - netReadBuffer = Utils.ensureCapacity(netReadBuffer, currentNetReadBufferSize); - if (netReadBuffer.position() >= currentNetReadBufferSize) { - throw new IllegalStateException("Buffer underflow when there is available data"); - } - } else if (handshakeResult.getStatus() == Status.CLOSED) { - throw new EOFException("SSL handshake status CLOSED during handshake UNWRAP"); + } while (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW); + if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) { + int currentNetReadBufferSize = netReadBufferSize(); + netReadBuffer = Utils.ensureCapacity(netReadBuffer, currentNetReadBufferSize); + if (netReadBuffer.position() >= currentNetReadBufferSize) { + 
throw new IllegalStateException("Buffer underflow when there is available data"); } - log.trace("SSLHandshake NEED_UNWRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", - channelId, handshakeResult, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); - - //if handshakeStatus completed than fall-through to finished status. - //after handshake is finished there is no data left to read/write in socketChannel. - //so the selector won't invoke this channel if we don't go through the handshakeFinished here. - if (handshakeStatus != HandshakeStatus.FINISHED) { - if (handshakeStatus == HandshakeStatus.NEED_WRAP) { - key.interestOps(key.interestOps() | SelectionKey.OP_WRITE); - } else if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) { - key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE); - } - break; + } else if (handshakeResult.getStatus() == Status.CLOSED) { + throw new EOFException("SSL handshake status CLOSED during handshake UNWRAP"); + } + log.trace("SSLHandshake NEED_UNWRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", + channelId, handshakeResult, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); + + //if handshakeStatus completed than fall-through to finished status. + //after handshake is finished there is no data left to read/write in socketChannel. + //so the selector won't invoke this channel if we don't go through the handshakeFinished here. 
+ if (handshakeStatus != HandshakeStatus.FINISHED) { + if (handshakeStatus == HandshakeStatus.NEED_WRAP) { + key.interestOps(key.interestOps() | SelectionKey.OP_WRITE); + } else if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) { + key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE); } - case FINISHED: - handshakeFinished(); - break; - case NOT_HANDSHAKING: - handshakeFinished(); break; - default: - throw new IllegalStateException(String.format("Unexpected status [%s]", handshakeStatus)); - } - - } catch (SSLException e) { - handshakeFailure(); - throw e; + } + case FINISHED: + handshakeFinished(); + break; + case NOT_HANDSHAKING: + handshakeFinished(); + break; + default: + throw new IllegalStateException(String.format("Unexpected status [%s]", handshakeStatus)); } } @@ -346,12 +400,12 @@ public class SslTransportLayer implements TransportLayer { // It can move from FINISHED status to NOT_HANDSHAKING after the handshake is completed. // Hence we also need to check handshakeResult.getHandshakeStatus() if the handshake finished or not if (handshakeResult.getHandshakeStatus() == HandshakeStatus.FINISHED) { - //we are complete if we have delivered the last package - handshakeComplete = !netWriteBuffer.hasRemaining(); + //we are complete if we have delivered the last packet //remove OP_WRITE if we are complete, otherwise we still have data to write - if (!handshakeComplete) + if (netWriteBuffer.hasRemaining()) key.interestOps(key.interestOps() | SelectionKey.OP_WRITE); else { + state = State.READY; key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE); SSLSession session = sslEngine.getSession(); log.debug("SSL handshake completed successfully with peerHost '{}' peerPort {} peerPrincipal '{}' cipherSuite '{}'", @@ -400,10 +454,9 @@ public class SslTransportLayer implements TransportLayer { private SSLEngineResult handshakeUnwrap(boolean doRead) throws IOException { log.trace("SSLHandshake handshakeUnwrap {}", channelId); SSLEngineResult result; - if 
(doRead) { - int read = socketChannel.read(netReadBuffer); - if (read == -1) throw new EOFException("EOF during handshake."); - } + int read = 0; + if (doRead) + read = readFromSocketChannel(); boolean cont; do { //prepare the buffer with the incoming data @@ -420,6 +473,11 @@ public class SslTransportLayer implements TransportLayer { log.trace("SSLHandshake handshakeUnwrap: handshakeStatus {} status {}", handshakeStatus, result.getStatus()); } while (netReadBuffer.position() != 0 && cont); + // Throw EOF exception for failed read after processing already received data + // so that handshake failures are reported correctly + if (read == -1) + throw new EOFException("EOF during handshake, handshake status is " + handshakeStatus); + return result; } @@ -429,27 +487,27 @@ public class SslTransportLayer implements TransportLayer { * * @param dst The buffer into which bytes are to be transferred * @return The number of bytes read, possible zero or -1 if the channel has reached end-of-stream + * and no more data is available * @throws IOException if some other I/O error occurs */ @Override public int read(ByteBuffer dst) throws IOException { - if (closing) return -1; - int read = 0; - if (!handshakeComplete) return read; + if (state == State.CLOSING) return -1; + else if (state != State.READY) return 0; //if we have unread decrypted data in appReadBuffer read that into dst buffer. 
+ int read = 0; if (appReadBuffer.position() > 0) { read = readFromAppBuffer(dst); } + int netread = 0; if (dst.remaining() > 0) { netReadBuffer = Utils.ensureCapacity(netReadBuffer, netReadBufferSize()); - if (netReadBuffer.remaining() > 0) { - int netread = socketChannel.read(netReadBuffer); - if (netread == 0 && netReadBuffer.position() == 0) return read; - else if (netread < 0) throw new EOFException("EOF during read"); - } - do { + if (netReadBuffer.remaining() > 0) + netread = readFromSocketChannel(); + + while (netReadBuffer.position() > 0) { netReadBuffer.flip(); SSLEngineResult unwrapResult = sslEngine.unwrap(netReadBuffer, appReadBuffer); netReadBuffer.compact(); @@ -493,8 +551,12 @@ public class SslTransportLayer implements TransportLayer { else break; } - } while (netReadBuffer.position() != 0); + } } + // If data has been read and unwrapped, return the data even if end-of-stream, channel will be closed + // on a subsequent poll. + if (read == 0 && netread < 0) + throw new EOFException("EOF during read"); return read; } @@ -553,8 +615,8 @@ public class SslTransportLayer implements TransportLayer { @Override public int write(ByteBuffer src) throws IOException { int written = 0; - if (closing) throw new IllegalStateException("Channel is in closing state"); - if (!handshakeComplete) return written; + if (state == State.CLOSING) throw new IllegalStateException("Channel is in closing state"); + if (state != State.READY) return written; if (!flush(netWriteBuffer)) return written; @@ -662,7 +724,7 @@ public class SslTransportLayer implements TransportLayer { public void addInterestOps(int ops) { if (!key.isValid()) throw new CancelledKeyException(); - else if (!handshakeComplete) + else if (state != State.READY) throw new IllegalStateException("handshake is not completed"); key.interestOps(key.interestOps() | ops); @@ -676,7 +738,7 @@ public class SslTransportLayer implements TransportLayer { public void removeInterestOps(int ops) { if (!key.isValid()) throw 
new CancelledKeyException(); - else if (!handshakeComplete) + else if (state != State.READY) throw new IllegalStateException("handshake is not completed"); key.interestOps(key.interestOps() & ~ops); @@ -723,7 +785,12 @@ public class SslTransportLayer implements TransportLayer { return netReadBuffer; } - private void handshakeFailure() { + /** + * SSL exceptions are propagated as authentication failures so that clients can avoid + * retries and report the failure. If `flush` is true, exceptions are propagated after + * any pending outgoing bytes are flushed to ensure that the peer is notified of the failure. + */ + private void handshakeFailure(SSLException sslException, boolean flush) throws IOException { //Release all resources such as internal buffers that SSLEngine is managing sslEngine.closeOutbound(); try { @@ -731,6 +798,17 @@ public class SslTransportLayer implements TransportLayer { } catch (SSLException e) { log.debug("SSLEngine.closeInBound() raised an exception.", e); } + + state = State.HANDSHAKE_FAILED; + handshakeException = new SslAuthenticationException("SSL handshake failed", sslException); + if (!flush || flush(netWriteBuffer)) + throw handshakeException; + } + + // If handshake has already failed, throw the authentication exception. 
+ private void maybeThrowSslAuthenticationException() { + if (handshakeException != null) + throw handshakeException; } @Override http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java ---------------------------------------------------------------------- diff --git a/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java b/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java index be56ad5..23f866b 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java +++ b/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java @@ -31,6 +31,7 @@ import java.nio.channels.GatheringByteChannel; import java.security.Principal; +import org.apache.kafka.common.errors.AuthenticationException; public interface TransportLayer extends ScatteringByteChannel, GatheringByteChannel { @@ -61,11 +62,14 @@ public interface TransportLayer extends ScatteringByteChannel, GatheringByteChan /** - * Performs SSL handshake hence is a no-op for the non-secure - * implementation - * @throws IOException + * This is a no-op for the non-secure PLAINTEXT implementation. For SSL, this performs + * SSL handshake. The SSL handshake includes client authentication if configured using + * {@link org.apache.kafka.common.config.SslConfigs#SSL_CLIENT_AUTH_CONFIG}. + * @throws AuthenticationException if handshake fails due to an + * {@link javax.net.ssl.SSLException}. + * @throws IOException if read or write fails with an I/O error.
*/ - void handshake() throws IOException; + void handshake() throws AuthenticationException, IOException; /** * Returns true if there are any pending writes http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java ---------------------------------------------------------------------- diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java index 8207a5a..d9e4f0c 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java @@ -20,8 +20,8 @@ import org.apache.kafka.clients.CommonClientConfigs; import org.apache.kafka.clients.NetworkClient; import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.config.SaslConfigs; -import org.apache.kafka.common.errors.AuthenticationException; import org.apache.kafka.common.errors.IllegalSaslStateException; +import org.apache.kafka.common.errors.SaslAuthenticationException; import org.apache.kafka.common.errors.UnsupportedSaslMechanismException; import org.apache.kafka.common.network.Authenticator; import org.apache.kafka.common.network.Mode; @@ -104,9 +104,6 @@ public class SaslClientAuthenticator implements Authenticator { private RequestHeader currentRequestHeader; // Version of SaslAuthenticate request/responses private short saslAuthenticateVersion; - // Sasl authentication error which may be one of NONE, UNSUPPORTED_SASL_MECHANISM, ILLEGAL_SASL_STATE, - // SASL_AUTHENTICATION_FAILED or NETWORK_EXCEPTION - private Errors error; public SaslClientAuthenticator(Map<String, ?> configs, String node, @@ -125,7 +122,6 @@ public class SaslClientAuthenticator implements Authenticator { this.transportLayer = 
transportLayer; this.configs = configs; this.saslAuthenticateVersion = DISABLE_KAFKA_SASL_AUTHENTICATE_HEADER; - this.error = Errors.NONE; try { setSaslState(handshakeRequestEnable ? SaslState.SEND_APIVERSIONS_REQUEST : SaslState.INITIAL); @@ -143,7 +139,7 @@ public class SaslClientAuthenticator implements Authenticator { saslClient = createSaslClient(); } catch (Exception e) { - throw new KafkaException("Failed to configure SaslClientAuthenticator", e); + throw new SaslAuthenticationException("Failed to configure SaslClientAuthenticator", e); } } @@ -158,7 +154,7 @@ public class SaslClientAuthenticator implements Authenticator { } }); } catch (PrivilegedActionException e) { - throw new KafkaException("Failed to create SaslClient with mechanism " + mechanism, e.getCause()); + throw new SaslAuthenticationException("Failed to create SaslClient with mechanism " + mechanism, e.getCause()); } } @@ -236,11 +232,6 @@ public class SaslClientAuthenticator implements Authenticator { } } - @Override - public Errors error() { - return error; - } - private RequestHeader nextRequestHeader(ApiKeys apiKey, short version) { String clientId = (String) configs.get(CommonClientConfigs.CLIENT_ID_CONFIG); currentRequestHeader = new RequestHeader(apiKey, version, clientId, correlationId++); @@ -345,8 +336,8 @@ public class SaslClientAuthenticator implements Authenticator { } else { SaslAuthenticateResponse response = (SaslAuthenticateResponse) receiveKafkaResponse(); if (response != null) { - this.error = response.error(); - if (this.error != Errors.NONE) { + Errors error = response.error(); + if (error != Errors.NONE) { setSaslState(SaslState.FAILED); String errMsg = response.errorMessage(); throw errMsg == null ? 
error.exception() : error.exception(errMsg); @@ -360,7 +351,7 @@ public class SaslClientAuthenticator implements Authenticator { private byte[] createSaslToken(final byte[] saslToken, boolean isInitial) throws SaslException { if (saslToken == null) - throw new SaslException("Error authenticating with the Kafka Broker: received a `null` saslToken."); + throw new IllegalSaslStateException("Error authenticating with the Kafka Broker: received a `null` saslToken."); try { if (isInitial && !saslClient.hasInitialResponse()) @@ -384,9 +375,9 @@ public class SaslClientAuthenticator implements Authenticator { " Users must configure FQDN of kafka brokers when authenticating using SASL and" + " `socketChannel.socket().getInetAddress().getHostName()` must match the hostname in `principal/hostname@realm`"; } - error += " Kafka Client will go to AUTH_FAILED state."; + error += " Kafka Client will go to AUTHENTICATION_FAILED state."; //Unwrap the SaslException inside `PrivilegedActionException` - throw new SaslException(error, e.getCause()); + throw new SaslAuthenticationException(error, e.getCause()); } } @@ -410,12 +401,12 @@ public class SaslClientAuthenticator implements Authenticator { } catch (SchemaException | IllegalArgumentException e) { LOG.debug("Invalid SASL mechanism response, server may be expecting only GSSAPI tokens"); setSaslState(SaslState.FAILED); - throw new AuthenticationException("Invalid SASL mechanism response", e); + throw new IllegalSaslStateException("Invalid SASL mechanism response, server may be expecting a different protocol", e); } } private void handleSaslHandshakeResponse(SaslHandshakeResponse response) { - this.error = response.error(); + Errors error = response.error(); if (error != Errors.NONE) setSaslState(SaslState.FAILED); switch (error) { @@ -428,7 +419,7 @@ public class SaslClientAuthenticator implements Authenticator { throw new IllegalSaslStateException(String.format("Unexpected handshake request with client mechanism %s, enabled 
mechanisms are %s", mechanism, response.enabledMechanisms())); default: - throw new AuthenticationException(String.format("Unknown error code %s, client mechanism is %s, enabled mechanisms are %s", + throw new IllegalSaslStateException(String.format("Unknown error code %s, client mechanism is %s, enabled mechanisms are %s", response.error(), mechanism, response.enabledMechanisms())); } } http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java ---------------------------------------------------------------------- diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java index 6202131..fe57d27 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java @@ -123,8 +123,6 @@ public class SaslServerAuthenticator implements Authenticator { private Send netOutBuffer; // flag indicating if sasl tokens are sent as Kafka SaslAuthenticate request/responses private boolean enableKafkaSaslAuthenticateHeaders; - // authentication error if authentication failed - private Errors error; public SaslServerAuthenticator(Map<String, ?> configs, String connectionId, @@ -144,7 +142,6 @@ public class SaslServerAuthenticator implements Authenticator { this.listenerName = listenerName; this.securityProtocol = securityProtocol; this.enableKafkaSaslAuthenticateHeaders = false; - this.error = Errors.NONE; this.transportLayer = transportLayer; @@ -288,11 +285,6 @@ public class SaslServerAuthenticator implements Authenticator { } @Override - public Errors error() { - return error; - } - - @Override public boolean complete() { return saslState == SaslState.COMPLETE; } @@ -366,13 +358,11 @@ public 
class SaslServerAuthenticator implements Authenticator { KafkaPrincipal.ANONYMOUS, listenerName, securityProtocol); RequestAndSize requestAndSize = requestContext.parseRequest(requestBuffer); if (apiKey != ApiKeys.SASL_AUTHENTICATE) { - this.error = Errors.ILLEGAL_SASL_STATE; IllegalSaslStateException e = new IllegalSaslStateException("Unexpected Kafka request of type " + apiKey + " during SASL authentication."); sendKafkaResponse(requestContext, requestAndSize.request.getErrorResponse(e)); throw e; } if (!apiKey.isVersionSupported(version)) { - this.error = Errors.UNSUPPORTED_VERSION; // We cannot create an error response if the request version of SaslAuthenticate is not supported // This should not normally occur since clients typically check supported versions using ApiVersionsRequest throw new UnsupportedVersionException("Version " + version + " is not supported for apiKey " + apiKey); @@ -385,8 +375,7 @@ public class SaslServerAuthenticator implements Authenticator { ByteBuffer responseBuf = responseToken == null ? 
EMPTY_BUFFER : ByteBuffer.wrap(responseToken); sendKafkaResponse(requestContext, new SaslAuthenticateResponse(Errors.NONE, null, responseBuf)); } catch (SaslException e) { - this.error = Errors.SASL_AUTHENTICATION_FAILED; - sendKafkaResponse(requestContext, new SaslAuthenticateResponse(this.error, + sendKafkaResponse(requestContext, new SaslAuthenticateResponse(Errors.SASL_AUTHENTICATION_FAILED, "Authentication failed due to invalid credentials with SASL mechanism " + saslMechanism)); throw e; } @@ -462,8 +451,7 @@ public class SaslServerAuthenticator implements Authenticator { return clientMechanism; } else { LOG.debug("SASL mechanism '{}' requested by client is not supported", clientMechanism); - this.error = Errors.UNSUPPORTED_SASL_MECHANISM; - sendKafkaResponse(context, new SaslHandshakeResponse(this.error, enabledMechanisms)); + sendKafkaResponse(context, new SaslHandshakeResponse(Errors.UNSUPPORTED_SASL_MECHANISM, enabledMechanisms)); throw new UnsupportedSaslMechanismException("Unsupported SASL mechanism " + clientMechanism); } } http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java ---------------------------------------------------------------------- diff --git a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java index e456d68..190fa3d 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java +++ b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java @@ -85,12 +85,14 @@ public class NioEchoServer extends Thread { acceptorThread.start(); while (serverSocketChannel.isOpen()) { selector.poll(1000); - for (SocketChannel socketChannel : newChannels) { - String id = id(socketChannel); - selector.register(id, socketChannel); - socketChannels.add(socketChannel); + synchronized (newChannels) { + for (SocketChannel socketChannel : newChannels) { + 
String id = id(socketChannel); + selector.register(id, socketChannel); + socketChannels.add(socketChannel); + } + newChannels.clear(); } - newChannels.clear(); List<NetworkReceive> completedReceives = selector.completedReceives(); for (NetworkReceive rcv : completedReceives) { http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java ---------------------------------------------------------------------- diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java index cffcc89..90c8cd5 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java +++ b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java @@ -48,8 +48,11 @@ import java.nio.channels.SocketChannel; import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -159,7 +162,7 @@ public class SslTransportLayerTest { InetSocketAddress addr = new InetSocketAddress("127.0.0.1", server.port()); selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); - NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state()); + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); } /** @@ -187,17 +190,13 @@ public class SslTransportLayerTest { sslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores); // Create a server with endpoint validation enabled on the server SSL engine - SslChannelBuilder serverChannelBuilder = new SslChannelBuilder(Mode.SERVER) { + SslChannelBuilder serverChannelBuilder = new 
TestSslChannelBuilder(Mode.SERVER) { @Override - protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key, String host) throws IOException { - SocketChannel socketChannel = (SocketChannel) key.channel(); - SSLEngine sslEngine = sslFactory.createSslEngine(host, socketChannel.socket().getPort()); + protected TestSslTransportLayer newTransportLayer(String id, SelectionKey key, SSLEngine sslEngine) throws IOException { SSLParameters sslParams = sslEngine.getSSLParameters(); sslParams.setEndpointIdentificationAlgorithm("HTTPS"); sslEngine.setSSLParameters(sslParams); - TestSslTransportLayer transportLayer = new TestSslTransportLayer(id, key, sslEngine, BUFFER_SIZE, BUFFER_SIZE, BUFFER_SIZE); - transportLayer.startHandshake(); - return transportLayer; + return super.newTransportLayer(id, key, sslEngine); } }; serverChannelBuilder.configure(sslServerConfigs); @@ -211,7 +210,7 @@ public class SslTransportLayerTest { NetworkTestUtils.checkClientConnection(selector, node, 100, 10); } - + /** * Tests that server certificate with invalid host name is not accepted by * a client that validates server endpoint. Server certificate uses @@ -230,9 +229,9 @@ public class SslTransportLayerTest { InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); - NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state()); + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); } - + /** * Tests that server certificate with invalid IP address is accepted by * a client that has disabled endpoint validation @@ -252,7 +251,7 @@ public class SslTransportLayerTest { NetworkTestUtils.checkClientConnection(selector, node, 100, 10); } - + /** * Tests that server accepts connections from clients with a trusted certificate * when client authentication is required. 
@@ -295,7 +294,7 @@ public class SslTransportLayerTest { sslClientConfigs.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG); createSelector(sslClientConfigs); selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); - NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state()); + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); selector.close(); server.close(); @@ -308,7 +307,7 @@ public class SslTransportLayerTest { selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); NetworkTestUtils.checkClientConnection(selector, node, 100, 10); } - + /** * Tests that server does not accept connections from clients with an untrusted certificate * when client authentication is required. @@ -323,9 +322,9 @@ public class SslTransportLayerTest { InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); - NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state()); + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); } - + /** * Tests that server does not accept connections from clients which don't * provide a certificate when client authentication is required. 
@@ -335,7 +334,7 @@ public class SslTransportLayerTest { String node = "0"; sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); server = createEchoServer(SecurityProtocol.SSL); - + sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG); sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG); sslClientConfigs.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG); @@ -343,9 +342,9 @@ public class SslTransportLayerTest { InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); - NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state()); + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); } - + /** * Tests that server accepts connections from a client configured * with an untrusted certificate if client authentication is disabled @@ -362,7 +361,7 @@ public class SslTransportLayerTest { NetworkTestUtils.checkClientConnection(selector, node, 100, 10); } - + /** * Tests that server accepts connections from a client that does not provide * a certificate if client authentication is disabled @@ -372,7 +371,7 @@ public class SslTransportLayerTest { String node = "0"; sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "none"); server = createEchoServer(SecurityProtocol.SSL); - + sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG); sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG); sslClientConfigs.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG); @@ -382,7 +381,7 @@ public class SslTransportLayerTest { NetworkTestUtils.checkClientConnection(selector, node, 100, 10); } - + /** * Tests that server accepts connections from a client configured * with a valid certificate if client authentication is requested @@ -398,7 +397,7 @@ public class SslTransportLayerTest { NetworkTestUtils.checkClientConnection(selector, node, 100, 10); } - + /** * Tests that server 
accepts connections from a client that does not provide * a certificate if client authentication is requested but not required @@ -408,7 +407,7 @@ public class SslTransportLayerTest { String node = "0"; sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "requested"); server = createEchoServer(SecurityProtocol.SSL); - + sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG); sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG); sslClientConfigs.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG); @@ -448,7 +447,7 @@ public class SslTransportLayerTest { // Expected exception } } - + /** * Tests that channels cannot be created if keystore cannot be loaded */ @@ -481,7 +480,7 @@ public class SslTransportLayerTest { NetworkTestUtils.checkClientConnection(selector, node, 100, 10); } - + /** * Tests that client connections cannot be created to a server * if key password is invalid @@ -495,9 +494,9 @@ public class SslTransportLayerTest { InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); - NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state()); + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); } - + /** * Tests that connections cannot be made with unsupported TLS versions */ @@ -506,15 +505,15 @@ public class SslTransportLayerTest { String node = "0"; sslServerConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList("TLSv1.2")); server = createEchoServer(SecurityProtocol.SSL); - + sslClientConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList("TLSv1.1")); createSelector(sslClientConfigs); InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); - NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state()); + NetworkTestUtils.waitForChannelClose(selector, node, 
ChannelState.State.AUTHENTICATION_FAILED); } - + /** * Tests that connections cannot be made with unsupported TLS cipher suites */ @@ -524,13 +523,13 @@ public class SslTransportLayerTest { String[] cipherSuites = SSLContext.getDefault().getDefaultSSLParameters().getCipherSuites(); sslServerConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Arrays.asList(cipherSuites[0])); server = createEchoServer(SecurityProtocol.SSL); - + sslClientConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Arrays.asList(cipherSuites[1])); createSelector(sslClientConfigs); InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); - NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state()); + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); } /** @@ -546,7 +545,7 @@ public class SslTransportLayerTest { NetworkTestUtils.checkClientConnection(selector, node, 64000, 10); } - + /** * Tests handling of BUFFER_OVERFLOW during wrap when network write buffer is smaller than SSL session packet buffer size. */ @@ -602,14 +601,98 @@ public class SslTransportLayerTest { } assertTrue("Send time not recorded", channel.getAndResetNetworkThreadTimeNanos() > 0); assertEquals("Time not reset", 0, channel.getAndResetNetworkThreadTimeNanos()); + assertFalse("Unexpected bytes buffered", channel.hasBytesBuffered()); + assertEquals(0, selector.completedReceives().size()); selector.unmute(node); while (selector.completedReceives().isEmpty()) { selector.poll(100L); + assertEquals(0, selector.numStagedReceives(channel)); } assertTrue("Receive time not recorded", channel.getAndResetNetworkThreadTimeNanos() > 0); } + /** + * Tests that IOExceptions from read during SSL handshake are not treated as authentication failures. 
+ */ + @Test + public void testIOExceptionsDuringHandshakeRead() throws Exception { + testIOExceptionsDuringHandshake(true, false); + } + + /** + * Tests that IOExceptions from write during SSL handshake are not treated as authentication failures. + */ + @Test + public void testIOExceptionsDuringHandshakeWrite() throws Exception { + testIOExceptionsDuringHandshake(false, true); + } + + private void testIOExceptionsDuringHandshake(boolean failRead, boolean failWrite) throws Exception { + server = createEchoServer(SecurityProtocol.SSL); + TestSslChannelBuilder channelBuilder = new TestSslChannelBuilder(Mode.CLIENT); + boolean done = false; + for (int i = 1; i <= 100; i++) { + int readFailureIndex = failRead ? i : Integer.MAX_VALUE; + int flushFailureIndex = failWrite ? i : Integer.MAX_VALUE; + String node = String.valueOf(i); + + channelBuilder.readFailureIndex = readFailureIndex; + channelBuilder.flushFailureIndex = flushFailureIndex; + channelBuilder.configure(sslClientConfigs); + this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder, new LogContext()); + + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + for (int j = 0; j < 30; j++) { + selector.poll(1000L); + KafkaChannel channel = selector.channel(node); + if (channel != null && channel.ready()) { + done = true; + break; + } + if (selector.disconnected().containsKey(node)) { + assertEquals(ChannelState.State.AUTHENTICATE, selector.disconnected().get(node).state()); + break; + } + } + KafkaChannel channel = selector.channel(node); + if (channel != null) + assertTrue("Channel not ready or disconnected:" + channel.state().state(), channel.ready()); + } + assertTrue("Too many invocations of read/write during SslTransportLayer.handshake()", done); + } + + /** + * Tests that handshake failures are propagated only after writes complete, even when + * there are delays in writes to ensure that 
clients see an authentication exception + * rather than a connection failure. + */ + @Test + public void testPeerNotifiedOfHandshakeFailure() throws Exception { + sslServerConfigs = serverCertStores.getUntrustingConfig(); + sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); + + // Test without delay and a couple of delay counts to ensure delay applies to handshake failure + for (int i = 0; i < 3; i++) { + String node = "0"; + TestSslChannelBuilder serverChannelBuilder = new TestSslChannelBuilder(Mode.SERVER); + serverChannelBuilder.configure(sslServerConfigs); + serverChannelBuilder.flushDelayCount = i; + server = new NioEchoServer(ListenerName.forSecurityProtocol(SecurityProtocol.SSL), + SecurityProtocol.SSL, new TestSecurityConfig(sslServerConfigs), + "localhost", serverChannelBuilder, null); + server.start(); + createSelector(sslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); + server.close(); + selector.close(); + } + } + @Test public void testCloseSsl() throws Exception { testClose(SecurityProtocol.SSL, new SslChannelBuilder(Mode.CLIENT)); @@ -654,24 +737,13 @@ public class SslTransportLayerTest { private void createSelector(Map<String, Object> sslClientConfigs) { createSelector(sslClientConfigs, null, null, null); - } + } private void createSelector(Map<String, Object> sslClientConfigs, final Integer netReadBufSize, final Integer netWriteBufSize, final Integer appBufSize) { - - this.channelBuilder = new SslChannelBuilder(Mode.CLIENT) { - - @Override - protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key, String host) throws IOException { - SocketChannel socketChannel = (SocketChannel) key.channel(); - SSLEngine sslEngine = sslFactory.createSslEngine(host, 
socketChannel.socket().getPort()); - TestSslTransportLayer transportLayer = new TestSslTransportLayer(id, key, sslEngine, netReadBufSize, netWriteBufSize, appBufSize); - transportLayer.startHandshake(); - return transportLayer; - } - - - }; + TestSslChannelBuilder channelBuilder = new TestSslChannelBuilder(Mode.CLIENT); + channelBuilder.configureBufferSizes(netReadBufSize, netWriteBufSize, appBufSize); + this.channelBuilder = channelBuilder; this.channelBuilder.configure(sslClientConfigs); this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder, new LogContext()); } @@ -683,47 +755,111 @@ public class SslTransportLayerTest { private NioEchoServer createEchoServer(SecurityProtocol securityProtocol) throws Exception { return createEchoServer(ListenerName.forSecurityProtocol(securityProtocol), securityProtocol); } - - /** - * SSLTransportLayer with overrides for packet and application buffer size to test buffer resize - * code path. The overridden buffer size starts with a small value and increases in size when the buffer - * size is retrieved to handle overflow/underflow, until the actual session buffer size is reached. 
- */ - private static class TestSslTransportLayer extends SslTransportLayer { - - private final ResizeableBufferSize netReadBufSize; - private final ResizeableBufferSize netWriteBufSize; - private final ResizeableBufferSize appBufSize; - - public TestSslTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine, - Integer netReadBufSize, Integer netWriteBufSize, Integer appBufSize) throws IOException { - super(channelId, key, sslEngine, false); - this.netReadBufSize = new ResizeableBufferSize(netReadBufSize); - this.netWriteBufSize = new ResizeableBufferSize(netWriteBufSize); - this.appBufSize = new ResizeableBufferSize(appBufSize); + + private static class TestSslChannelBuilder extends SslChannelBuilder { + + private Integer netReadBufSizeOverride; + private Integer netWriteBufSizeOverride; + private Integer appBufSizeOverride; + long readFailureIndex = Long.MAX_VALUE; + long flushFailureIndex = Long.MAX_VALUE; + int flushDelayCount = 0; + + public TestSslChannelBuilder(Mode mode) { + super(mode); } - - @Override - protected int netReadBufferSize() { - ByteBuffer netReadBuffer = netReadBuffer(); - // netReadBufferSize() is invoked in SSLTransportLayer.read() prior to the read - // operation. To avoid the read buffer being expanded too early, increase buffer size - // only when read buffer is full. This ensures that BUFFER_UNDERFLOW is always - // triggered in testNetReadBufferResize(). 
- boolean updateBufSize = netReadBuffer != null && !netReadBuffer().hasRemaining(); - return netReadBufSize.updateAndGet(super.netReadBufferSize(), updateBufSize); + + public void configureBufferSizes(Integer netReadBufSize, Integer netWriteBufSize, Integer appBufSize) { + this.netReadBufSizeOverride = netReadBufSize; + this.netWriteBufSizeOverride = netWriteBufSize; + this.appBufSizeOverride = appBufSize; } - + @Override - protected int netWriteBufferSize() { - return netWriteBufSize.updateAndGet(super.netWriteBufferSize(), true); + protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key, String host) throws IOException { + SocketChannel socketChannel = (SocketChannel) key.channel(); + SSLEngine sslEngine = sslFactory.createSslEngine(host, socketChannel.socket().getPort()); + TestSslTransportLayer transportLayer = newTransportLayer(id, key, sslEngine); + transportLayer.startHandshake(); + return transportLayer; } - @Override - protected int applicationBufferSize() { - return appBufSize.updateAndGet(super.applicationBufferSize(), true); + protected TestSslTransportLayer newTransportLayer(String id, SelectionKey key, SSLEngine sslEngine) throws IOException { + return new TestSslTransportLayer(id, key, sslEngine); + } + + /** + * SSLTransportLayer with overrides for testing including: + * <ul> + * <li>Overrides for packet and application buffer size to test buffer resize code path. 
+ * The overridden buffer size starts with a small value and increases in size when the buffer size + * is retrieved to handle overflow/underflow, until the actual session buffer size is reached.</li> + * <li>IOException injection for reads and writes for testing exception handling during handshakes.</li> + * <li>Delayed writes to test handshake failure notifications to peer</li> + * </ul> + */ + class TestSslTransportLayer extends SslTransportLayer { + + private final ResizeableBufferSize netReadBufSize; + private final ResizeableBufferSize netWriteBufSize; + private final ResizeableBufferSize appBufSize; + private final AtomicLong numReadsRemaining; + private final AtomicLong numFlushesRemaining; + private final AtomicInteger numDelayedFlushesRemaining; + + public TestSslTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine) throws IOException { + super(channelId, key, sslEngine, false); + this.netReadBufSize = new ResizeableBufferSize(netReadBufSizeOverride); + this.netWriteBufSize = new ResizeableBufferSize(netWriteBufSizeOverride); + this.appBufSize = new ResizeableBufferSize(appBufSizeOverride); + numReadsRemaining = new AtomicLong(readFailureIndex); + numFlushesRemaining = new AtomicLong(flushFailureIndex); + numDelayedFlushesRemaining = new AtomicInteger(flushDelayCount); + } + + @Override + protected int netReadBufferSize() { + ByteBuffer netReadBuffer = netReadBuffer(); + // netReadBufferSize() is invoked in SSLTransportLayer.read() prior to the read + // operation. To avoid the read buffer being expanded too early, increase buffer size + // only when read buffer is full. This ensures that BUFFER_UNDERFLOW is always + // triggered in testNetReadBufferResize(). 
+ boolean updateBufSize = netReadBuffer != null && !netReadBuffer().hasRemaining(); + return netReadBufSize.updateAndGet(super.netReadBufferSize(), updateBufSize); + } + + @Override + protected int netWriteBufferSize() { + return netWriteBufSize.updateAndGet(super.netWriteBufferSize(), true); + } + + @Override + protected int applicationBufferSize() { + return appBufSize.updateAndGet(super.applicationBufferSize(), true); + } + + @Override + protected int readFromSocketChannel() throws IOException { + if (numReadsRemaining.decrementAndGet() == 0 && !ready()) + throw new IOException("Test exception during read"); + return super.readFromSocketChannel(); + } + + @Override + protected boolean flush(ByteBuffer buf) throws IOException { + if (numFlushesRemaining.decrementAndGet() == 0 && !ready()) + throw new IOException("Test exception during write"); + else if (numDelayedFlushesRemaining.getAndDecrement() != 0) + return false; + resetDelayedFlush(); + return super.flush(buf); + } + + private void resetDelayedFlush() { + numDelayedFlushesRemaining.set(flushDelayCount); + } } - + private static class ResizeableBufferSize { private Integer bufSizeOverride; ResizeableBufferSize(Integer bufSizeOverride) { @@ -740,5 +876,4 @@ public class SslTransportLayerTest { } } } - }