GitHub user anmolnar commented on a diff in the pull request:
https://github.com/apache/zookeeper/pull/679#discussion_r233647195
--- Diff:
zookeeper-server/src/test/java/org/apache/zookeeper/server/quorum/UnifiedServerSocketTest.java
---
@@ -17,156 +17,644 @@
*/
package org.apache.zookeeper.server.quorum;
+import java.io.BufferedInputStream;
+import java.io.IOException;
+import java.net.ConnectException;
+import java.net.InetSocketAddress;
+import java.net.Socket;
+import java.net.SocketException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Random;
+
+import javax.net.ssl.HandshakeCompletedEvent;
+import javax.net.ssl.HandshakeCompletedListener;
+import javax.net.ssl.SSLSocket;
+
import org.apache.zookeeper.PortAssignment;
import org.apache.zookeeper.client.ZKClientConfig;
+import org.apache.zookeeper.common.BaseX509ParameterizedTestCase;
import org.apache.zookeeper.common.ClientX509Util;
-import org.apache.zookeeper.common.Time;
+import org.apache.zookeeper.common.KeyStoreFileType;
+import org.apache.zookeeper.common.X509Exception;
+import org.apache.zookeeper.common.X509KeyType;
+import org.apache.zookeeper.common.X509TestContext;
import org.apache.zookeeper.common.X509Util;
import org.apache.zookeeper.server.ServerCnxnFactory;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
-import javax.net.ssl.HandshakeCompletedEvent;
-import javax.net.ssl.HandshakeCompletedListener;
-import javax.net.ssl.SSLSocket;
-import java.io.IOException;
-import java.net.ConnectException;
-import java.net.InetSocketAddress;
-import java.net.Socket;
-
-import static org.hamcrest.CoreMatchers.equalTo;
-import static org.junit.Assert.assertThat;
+@RunWith(Parameterized.class)
+public class UnifiedServerSocketTest extends BaseX509ParameterizedTestCase
{
-public class UnifiedServerSocketTest {
+ @Parameterized.Parameters
+ public static Collection<Object[]> params() {
+ ArrayList<Object[]> result = new ArrayList<>();
+ int paramIndex = 0;
+ for (X509KeyType caKeyType : X509KeyType.values()) {
+ for (X509KeyType certKeyType : X509KeyType.values()) {
+ for (Boolean hostnameVerification : new Boolean[] { true,
false }) {
+ result.add(new Object[]{
+ caKeyType,
+ certKeyType,
+ hostnameVerification,
+ paramIndex++
+ });
+ }
+ }
+ }
+ return result;
+ }
private static final int MAX_RETRIES = 5;
private static final int TIMEOUT = 1000;
+ private static final byte[] DATA_TO_CLIENT = "hello client".getBytes();
+ private static final byte[] DATA_FROM_CLIENT = "hello
server".getBytes();
private X509Util x509Util;
private int port;
- private volatile boolean handshakeCompleted;
+ private InetSocketAddress localServerAddress;
+ private final Object handshakeCompletedLock = new Object();
+ // access only inside synchronized(handshakeCompletedLock) { ... }
blocks
+ private boolean handshakeCompleted = false;
+
+ public UnifiedServerSocketTest(
+ final X509KeyType caKeyType,
+ final X509KeyType certKeyType,
+ final Boolean hostnameVerification,
+ final Integer paramIndex) {
+ super(paramIndex, () -> {
+ try {
+ return X509TestContext.newBuilder()
+ .setTempDir(tempDir)
+ .setKeyStoreKeyType(certKeyType)
+ .setTrustStoreKeyType(caKeyType)
+ .setHostnameVerification(hostnameVerification)
+ .build();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
@Before
public void setUp() throws Exception {
- handshakeCompleted = false;
-
port = PortAssignment.unique();
+ localServerAddress = new InetSocketAddress("localhost", port);
- String testDataPath = System.getProperty("test.data.dir",
"build/test/data");
System.setProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY,
"org.apache.zookeeper.server.NettyServerCnxnFactory");
System.setProperty(ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET,
"org.apache.zookeeper.ClientCnxnSocketNetty");
System.setProperty(ZKClientConfig.SECURE_CLIENT, "true");
x509Util = new ClientX509Util();
- System.setProperty(x509Util.getSslKeystoreLocationProperty(),
testDataPath + "/ssl/testKeyStore.jks");
- System.setProperty(x509Util.getSslKeystorePasswdProperty(),
"testpass");
- System.setProperty(x509Util.getSslTruststoreLocationProperty(),
testDataPath + "/ssl/testTrustStore.jks");
- System.setProperty(x509Util.getSslTruststorePasswdProperty(),
"testpass");
-
System.setProperty(x509Util.getSslHostnameVerificationEnabledProperty(),
"false");
+ x509TestContext.setSystemProperties(x509Util,
KeyStoreFileType.JKS, KeyStoreFileType.JKS);
}
- @Test
- public void testConnectWithSSL() throws Exception {
- class ServerThread extends Thread {
- public void run() {
- try {
- Socket unifiedSocket = new
UnifiedServerSocket(x509Util, port).accept();
- ((SSLSocket)unifiedSocket).getSession(); // block
until handshake completes
- } catch (IOException e) {
- e.printStackTrace();
+ private static void forceClose(java.io.Closeable s) {
+ if (s == null) {
+ return;
+ }
+ try {
+ s.close();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ private static final class UnifiedServerThread extends Thread {
+ private final byte[] dataToClient;
+ private List<byte[]> dataFromClients;
+ private List<Thread> workerThreads;
+ private UnifiedServerSocket serverSocket;
+
+ UnifiedServerThread(X509Util x509Util,
+ InetSocketAddress bindAddress,
+ boolean allowInsecureConnection,
+ byte[] dataToClient) throws IOException {
+ this.dataToClient = dataToClient;
+ dataFromClients = new ArrayList<>();
+ workerThreads = new ArrayList<>();
+ serverSocket = new UnifiedServerSocket(x509Util,
allowInsecureConnection);
+ serverSocket.bind(bindAddress);
+ }
+
+ @Override
+ public void run() {
+ try {
+ Random rnd = new Random();
+ while (true) {
+ final Socket unifiedSocket = serverSocket.accept();
+ final boolean tcpNoDelay = rnd.nextBoolean();
+ unifiedSocket.setTcpNoDelay(tcpNoDelay);
+ unifiedSocket.setSoTimeout(TIMEOUT);
+ final boolean keepAlive = rnd.nextBoolean();
+ unifiedSocket.setKeepAlive(keepAlive);
+ // Note: getting the input stream should not block the
thread or trigger mode detection.
+ BufferedInputStream bis = new
BufferedInputStream(unifiedSocket.getInputStream());
+ Thread t = new Thread(new Runnable() {
+ @Override
+ public void run() {
+ try {
+ byte[] buf = new byte[1024];
+ int bytesRead =
unifiedSocket.getInputStream().read(buf, 0, 1024);
+ // Make sure the settings applied above
before the socket was potentially upgraded to
+ // TLS still apply.
+ Assert.assertEquals(tcpNoDelay,
unifiedSocket.getTcpNoDelay());
+ Assert.assertEquals(TIMEOUT,
unifiedSocket.getSoTimeout());
+ Assert.assertEquals(keepAlive,
unifiedSocket.getKeepAlive());
+ if (bytesRead > 0) {
+ byte[] dataFromClient = new
byte[bytesRead];
+ System.arraycopy(buf, 0,
dataFromClient, 0, bytesRead);
+ synchronized (dataFromClients) {
+
dataFromClients.add(dataFromClient);
+ }
+ }
+
unifiedSocket.getOutputStream().write(dataToClient);
+ unifiedSocket.getOutputStream().flush();
+ } catch (IOException e) {
+ e.printStackTrace();
+ throw new RuntimeException(e);
+ } finally {
+ forceClose(unifiedSocket);
+ }
+ }
+ });
+ workerThreads.add(t);
+ t.start();
}
+ } catch (IOException e) {
+ e.printStackTrace();
+ throw new RuntimeException(e);
+ } finally {
+ forceClose(serverSocket);
}
}
- ServerThread serverThread = new ServerThread();
- serverThread.start();
+ public void shutdown(long millis) throws InterruptedException {
+ forceClose(serverSocket); // this should break the run() loop
+ for (Thread t : workerThreads) {
+ t.join(millis);
+ }
+ this.join(millis);
+ }
+
+ synchronized byte[] getDataFromClient(int index) {
+ return dataFromClients.get(index);
+ }
+ }
+
+ private SSLSocket connectWithSSL() throws IOException, X509Exception,
InterruptedException {
SSLSocket sslSocket = null;
int retries = 0;
while (retries < MAX_RETRIES) {
try {
sslSocket = x509Util.createSSLSocket();
+ sslSocket.addHandshakeCompletedListener(new
HandshakeCompletedListener() {
+ @Override
+ public void handshakeCompleted(HandshakeCompletedEvent
handshakeCompletedEvent) {
+ synchronized (handshakeCompletedLock) {
+ handshakeCompleted = true;
+ handshakeCompletedLock.notifyAll();
+ }
+ }
+ });
sslSocket.setSoTimeout(TIMEOUT);
- sslSocket.connect(new InetSocketAddress(port), TIMEOUT);
+ sslSocket.connect(localServerAddress, TIMEOUT);
break;
} catch (ConnectException connectException) {
connectException.printStackTrace();
+ forceClose(sslSocket);
+ sslSocket = null;
Thread.sleep(TIMEOUT);
}
retries++;
}
- sslSocket.addHandshakeCompletedListener(new
HandshakeCompletedListener() {
- @Override
- public void handshakeCompleted(HandshakeCompletedEvent
handshakeCompletedEvent) {
- completeHandshake();
+ Assert.assertNotNull("Failed to connect to server with SSL",
sslSocket);
+ return sslSocket;
+ }
+
+ private Socket connectWithoutSSL() throws IOException,
InterruptedException {
+ Socket socket = null;
+ int retries = 0;
+ while (retries < MAX_RETRIES) {
+ try {
+ socket = new Socket();
+ socket.setSoTimeout(TIMEOUT);
+ socket.connect(localServerAddress, TIMEOUT);
+ break;
+ } catch (ConnectException connectException) {
+ connectException.printStackTrace();
+ forceClose(socket);
+ socket = null;
+ Thread.sleep(TIMEOUT);
}
- });
- sslSocket.startHandshake();
+ retries++;
+ }
+ Assert.assertNotNull("Failed to connect to server without SSL",
socket);
+ return socket;
+ }
+
+ // In the tests below, a "Strict" server means a UnifiedServerSocket
that
+ // does not allow plaintext connections (in other words, it's
SSL-only).
+ // A "Non Strict" server means a UnifiedServerSocket that allows both
+ // plaintext and SSL incoming connections.
+
+ /**
+ * Attempting to connect to a SSL-or-plaintext server with SSL should
work.
+ */
+ @Test
+ public void testConnectWithSSLToNonStrictServer() throws Exception {
+ UnifiedServerThread serverThread = new UnifiedServerThread(
+ x509Util, localServerAddress, true, DATA_TO_CLIENT);
+ serverThread.start();
- serverThread.join(TIMEOUT);
+ Socket sslSocket = connectWithSSL();
+ sslSocket.getOutputStream().write(DATA_FROM_CLIENT);
+ sslSocket.getOutputStream().flush();
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
+ int bytesRead = sslSocket.getInputStream().read(buf, 0,
buf.length);
+ Assert.assertEquals(buf.length, bytesRead);
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
- long start = Time.currentElapsedTime();
- while (Time.currentElapsedTime() < start + TIMEOUT) {
- if (handshakeCompleted) {
- return;
+ serverThread.shutdown(TIMEOUT);
+ forceClose(sslSocket);
+
+ synchronized (handshakeCompletedLock) {
+ if (!handshakeCompleted) {
+ handshakeCompletedLock.wait(TIMEOUT);
}
+ Assert.assertTrue(handshakeCompleted);
}
+ Assert.assertArrayEquals(DATA_FROM_CLIENT,
serverThread.getDataFromClient(0));
+ }
+
+ /**
+ * Attempting to connect to a SSL-only server with SSL should work.
+ */
+ @Test
+ public void testConnectWithSSLToStrictServer() throws Exception {
+ UnifiedServerThread serverThread = new UnifiedServerThread(
+ x509Util, localServerAddress, false, DATA_TO_CLIENT);
+ serverThread.start();
+
+ Socket sslSocket = connectWithSSL();
+ sslSocket.getOutputStream().write(DATA_FROM_CLIENT);
+ sslSocket.getOutputStream().flush();
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
+ int bytesRead = sslSocket.getInputStream().read(buf, 0,
buf.length);
+ Assert.assertEquals(buf.length, bytesRead);
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
- Assert.fail("failed to complete handshake");
+ serverThread.shutdown(TIMEOUT);
+ forceClose(sslSocket);
+
+ synchronized (handshakeCompletedLock) {
+ if (!handshakeCompleted) {
+ handshakeCompletedLock.wait(TIMEOUT);
+ }
+ Assert.assertTrue(handshakeCompleted);
+ }
+
+ Assert.assertArrayEquals(DATA_FROM_CLIENT,
serverThread.getDataFromClient(0));
}
- private void completeHandshake() {
- handshakeCompleted = true;
+ /**
+ * Attempting to connect to a SSL-or-plaintext server without SSL
should work.
+ */
+ @Test
+ public void testConnectWithoutSSLToNonStrictServer() throws Exception {
+ UnifiedServerThread serverThread = new UnifiedServerThread(
+ x509Util, localServerAddress, true, DATA_TO_CLIENT);
+ serverThread.start();
+
+ Socket socket = connectWithoutSSL();
+ socket.getOutputStream().write(DATA_FROM_CLIENT);
+ socket.getOutputStream().flush();
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
+ int bytesRead = socket.getInputStream().read(buf, 0, buf.length);
+ Assert.assertEquals(buf.length, bytesRead);
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
+
+ serverThread.shutdown(TIMEOUT);
+ forceClose(socket);
+
+ Assert.assertArrayEquals(DATA_FROM_CLIENT,
serverThread.getDataFromClient(0));
}
+ /**
+ * Attempting to connect to a SSL-or-plaintext server without SSL with
a
+ * small initial data write should work. This makes sure that sending
+ * less than 5 bytes does not break the logic in the server's initial 5
+ * byte read.
+ */
@Test
- public void testConnectWithoutSSL() throws Exception {
- final byte[] testData = "hello there".getBytes();
- final String[] dataReadFromClient = {null};
+ public void testConnectWithoutSSLToNonStrictServerPartialWrite()
throws Exception {
+ UnifiedServerThread serverThread = new UnifiedServerThread(
+ x509Util, localServerAddress, true, DATA_TO_CLIENT);
+ serverThread.start();
+
+ Socket socket = connectWithoutSSL();
+ // Write only 2 bytes of the message, wait a bit, then write the
rest.
+ // This makes sure that writes smaller than 5 bytes don't break
the plaintext mode on the server
+ // once it decides that the input doesn't look like a TLS
handshake.
+ socket.getOutputStream().write(DATA_FROM_CLIENT, 0, 2);
+ socket.getOutputStream().flush();
+ Thread.sleep(TIMEOUT / 2);
+ socket.getOutputStream().write(DATA_FROM_CLIENT, 2,
DATA_FROM_CLIENT.length - 2);
+ socket.getOutputStream().flush();
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
+ int bytesRead = socket.getInputStream().read(buf, 0, buf.length);
+ Assert.assertEquals(buf.length, bytesRead);
+ Assert.assertArrayEquals(DATA_TO_CLIENT, buf);
+
+ serverThread.shutdown(TIMEOUT);
+ forceClose(socket);
- class ServerThread extends Thread {
+ Assert.assertArrayEquals(DATA_FROM_CLIENT,
serverThread.getDataFromClient(0));
+ }
+
+ /**
+ * Attempting to connect to a SSL-only server without SSL should fail.
+ */
+ @Test
+ public void testConnectWithoutSSLToStrictServer() throws Exception {
+ UnifiedServerThread serverThread = new UnifiedServerThread(
+ x509Util, localServerAddress, false, DATA_TO_CLIENT);
+ serverThread.start();
+
+ Socket socket = connectWithoutSSL();
+ socket.getOutputStream().write(DATA_FROM_CLIENT);
+ socket.getOutputStream().flush();
+ byte[] buf = new byte[DATA_TO_CLIENT.length];
+ try {
+ socket.getInputStream().read(buf, 0, buf.length);
+ } catch (SocketException e) {
+ // We expect the other end to hang up the connection
+ return;
+ } finally {
+ serverThread.shutdown(TIMEOUT);
+ forceClose(socket);
+ }
+ Assert.fail("Expected server to hang up the connection. Read from
server succeeded unexpectedly.");
+ }
+
+ /**
+ * This test makes sure that UnifiedServerSocket used properly (a
single thread accept()-ing connections and
+ * handing the resulting sockets to other threads for processing) is
not vulnerable to a simple denial-of-service
+ * attack in which a client connects and never writes any bytes. This
should not block the accepting thread, since
+ * the read to determine if the client is sending a TLS handshake or
not happens in the processing thread.
+ *
+ * This version of the test uses a non-strict server socket (i.e. it
accepts both TLS and plaintext connections).
+ */
+ @Test
+ public void testDenialOfServiceResistanceNonStrictServer() throws
Exception {
+ UnifiedServerThread serverThread = new UnifiedServerThread(
--- End diff --
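In these DoS tests you're actually testing the `UnifiedServerThread`
implementation, which is test-only code. Does that make sense, or am I missing
something here? As far as I can tell, the tests only show that
`UnifiedServerSocket` resists this attack when it is used the way
`UnifiedServerThread` uses it (one thread calling accept() and handing each
socket to a worker thread). They don't verify that the production code follows
that pattern.
In "reality" the accepting thread is basically the `Leader`, and the one that
does the TLS handshake is the `LearnerHandler`.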
---