This is an automated email from the ASF dual-hosted git repository.

burcham pushed a commit to branch feature/GEODE-8419-backport
in repository https://gitbox.apache.org/repos/asf/geode.git

commit 524dd58ecc6bed33d4b1774b2180a8d47179f638
Author: Bill Burcham <bill.burc...@gmail.com>
AuthorDate: Fri Oct 2 15:22:08 2020 -0700

    GEODE-8419: backported forward
---
 ...LSocketHostNameVerificationIntegrationTest.java |   4 +-
 .../internal/net/SSLSocketIntegrationTest.java     |   5 +-
 .../apache/geode/internal/net/SocketCreator.java   | 136 +++++++++++++++++----
 .../org/apache/geode/internal/tcp/Connection.java  |   3 +-
 .../geode/internal/net/SocketCreatorJUnitTest.java |  56 +++++++++
 5 files changed, 176 insertions(+), 28 deletions(-)

diff --git a/geode-core/src/integrationTest/java/org/apache/geode/internal/net/SSLSocketHostNameVerificationIntegrationTest.java b/geode-core/src/integrationTest/java/org/apache/geode/internal/net/SSLSocketHostNameVerificationIntegrationTest.java
index 5483457..dc7df44 100755
--- a/geode-core/src/integrationTest/java/org/apache/geode/internal/net/SSLSocketHostNameVerificationIntegrationTest.java
+++ b/geode-core/src/integrationTest/java/org/apache/geode/internal/net/SSLSocketHostNameVerificationIntegrationTest.java
@@ -168,7 +168,7 @@ public class SSLSocketHostNameVerificationIntegrationTest {
     this.clientSocket = clientChannel.socket();
 
     SSLEngine sslEngine =
-        this.socketCreator.createSSLEngine(this.localHost.getHostName(), 1234);
+        this.socketCreator.createSSLEngine(this.localHost.getHostName(), 1234, true);
 
     try {
       this.socketCreator.handshakeSSLSocketChannel(clientSocket.getChannel(),
@@ -200,7 +200,7 @@ public class SSLSocketHostNameVerificationIntegrationTest {
       try {
         socket = serverSocket.accept();
         SocketCreator sc = SocketCreatorFactory.getSocketCreatorForComponent(CLUSTER);
-        final SSLEngine sslEngine = sc.createSSLEngine(this.localHost.getHostName(), 1234);
+        final SSLEngine sslEngine = sc.createSSLEngine(this.localHost.getHostName(), 1234, false);
         engine =
             sc.handshakeSSLSocketChannel(socket.getChannel(),
                 sslEngine,
diff --git a/geode-core/src/integrationTest/java/org/apache/geode/internal/net/SSLSocketIntegrationTest.java b/geode-core/src/integrationTest/java/org/apache/geode/internal/net/SSLSocketIntegrationTest.java
index 4e6747b..078ba33 100755
--- a/geode-core/src/integrationTest/java/org/apache/geode/internal/net/SSLSocketIntegrationTest.java
+++ b/geode-core/src/integrationTest/java/org/apache/geode/internal/net/SSLSocketIntegrationTest.java
@@ -217,7 +217,7 @@ public class SSLSocketIntegrationTest {
     clientSocket = clientChannel.socket();
     NioSslEngine engine =
         clusterSocketCreator.handshakeSSLSocketChannel(clientSocket.getChannel(),
-            clusterSocketCreator.createSSLEngine("localhost", 1234), 0, true,
+            clusterSocketCreator.createSSLEngine("localhost", 1234, true), 0, true,
             ByteBuffer.allocate(65535), new BufferPool(mock(DMStats.class)));
     clientChannel.configureBlocking(true);
 
@@ -262,7 +262,8 @@ public class SSLSocketIntegrationTest {
         socket = serverSocket.accept();
         SocketCreator sc = SocketCreatorFactory.getSocketCreatorForComponent(CLUSTER);
         engine =
-            sc.handshakeSSLSocketChannel(socket.getChannel(), sc.createSSLEngine("localhost", 1234),
+            sc.handshakeSSLSocketChannel(socket.getChannel(), sc.createSSLEngine("localhost", 1234,
+                false),
                 timeoutMillis,
                 false,
                 ByteBuffer.allocate(65535),
diff --git a/geode-core/src/main/java/org/apache/geode/internal/net/SocketCreator.java b/geode-core/src/main/java/org/apache/geode/internal/net/SocketCreator.java
index 427e758..a232fca 100755
--- a/geode-core/src/main/java/org/apache/geode/internal/net/SocketCreator.java
+++ b/geode-core/src/main/java/org/apache/geode/internal/net/SocketCreator.java
@@ -38,6 +38,8 @@ import java.security.PrivateKey;
 import java.security.UnrecoverableKeyException;
 import java.security.cert.CertificateException;
 import java.security.cert.X509Certificate;
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
@@ -47,6 +49,8 @@ import javax.net.ServerSocketFactory;
 import javax.net.SocketFactory;
 import javax.net.ssl.KeyManager;
 import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SNIHostName;
+import javax.net.ssl.SNIServerName;
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLEngine;
 import javax.net.ssl.SSLException;
@@ -56,16 +60,19 @@ import javax.net.ssl.SSLPeerUnverifiedException;
 import javax.net.ssl.SSLProtocolException;
 import javax.net.ssl.SSLServerSocket;
 import javax.net.ssl.SSLSocket;
+import javax.net.ssl.StandardConstants;
 import javax.net.ssl.TrustManager;
 import javax.net.ssl.TrustManagerFactory;
 import javax.net.ssl.X509ExtendedKeyManager;
 
 import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.validator.routines.InetAddressValidator;
 import org.apache.logging.log4j.Logger;
 
 import org.apache.geode.GemFireConfigException;
 import org.apache.geode.SystemConnectException;
 import org.apache.geode.SystemFailure;
+import org.apache.geode.annotations.VisibleForTesting;
 import org.apache.geode.annotations.internal.MakeNotStatic;
 import org.apache.geode.cache.wan.GatewaySender;
 import org.apache.geode.cache.wan.GatewayTransportFilter;
@@ -173,6 +180,11 @@ public class SocketCreator extends TcpSocketCreatorImpl {
     initialize();
   }
 
+  @VisibleForTesting
+  SocketCreator(final SSLConfig sslConfig, SSLContext sslContext) {
+    this.sslConfig = sslConfig;
+    this.sslContext = sslContext;
+  }
 
   // -------------------------------------------------------------------------
   // Static instance accessors
@@ -687,7 +699,7 @@ public class SocketCreator extends TcpSocketCreatorImpl {
         optionalWatcher.beforeConnect(socket);
       }
       socket.connect(sockaddr, Math.max(timeout, 0));
-      configureClientSSLSocket(socket, timeout);
+      configureClientSSLSocket(socket, inetadd.getHostName(), timeout);
       return socket;
 
     } finally {
@@ -708,8 +720,79 @@ public class SocketCreator extends TcpSocketCreatorImpl {
   /**
    * Returns an SSLEngine that can be used to perform TLS handshakes and communication
    */
-  public SSLEngine createSSLEngine(String hostName, int port) {
-    return sslContext.createSSLEngine(hostName, port);
+  public SSLEngine createSSLEngine(String hostName, int port, final boolean clientSocket) {
+    SSLEngine engine = getSslContext().createSSLEngine(hostName, port);
+    configureSSLEngine(engine, hostName, port, clientSocket);
+    return engine;
+  }
+
+  @VisibleForTesting
+  void configureSSLEngine(SSLEngine engine, String hostName, int port, boolean clientSocket) {
+    SSLParameters parameters = engine.getSSLParameters();
+    boolean updateEngineWithParameters = false;
+    if (sslConfig.doEndpointIdentification()) {
+      // set server-names so that endpoint identification algorithms can find what's expected
+      if (setServerNames(parameters, hostName)) {
+        updateEngineWithParameters = true;
+      }
+    }
+
+    engine.setUseClientMode(clientSocket);
+    if (!clientSocket) {
+      engine.setNeedClientAuth(sslConfig.isRequireAuth());
+    }
+
+    if (clientSocket) {
+      if (checkAndEnableHostnameValidation(parameters)) {
+        updateEngineWithParameters = true;
+      }
+    }
+
+    String[] protocols = this.sslConfig.getProtocolsAsStringArray();
+
+    if (protocols != null && !"any".equalsIgnoreCase(protocols[0])) {
+      engine.setEnabledProtocols(protocols);
+    }
+
+    String[] ciphers = this.sslConfig.getCiphersAsStringArray();
+    if (ciphers != null && !"any".equalsIgnoreCase(ciphers[0])) {
+      engine.setEnabledCipherSuites(ciphers);
+    }
+
+    if (updateEngineWithParameters) {
+      engine.setSSLParameters(parameters);
+    }
+  }
+
+  /**
+   * returns true if the SSLParameters are altered, false if not
+   */
+  private boolean setServerNames(SSLParameters modifiedParams, String hostName) {
+    List<SNIServerName> oldNames = modifiedParams.getServerNames();
+    oldNames = oldNames == null ? Collections.emptyList() : oldNames;
+    final List<SNIServerName> serverNames = new ArrayList<>(oldNames);
+
+    if (serverNames.stream()
+        .mapToInt(SNIServerName::getType)
+        .anyMatch(type -> type == StandardConstants.SNI_HOST_NAME)) {
+      // we already have a SNI hostname set. Do nothing.
+      return false;
+    }
+
+    if (this.sslConfig.doEndpointIdentification()
+        && InetAddressValidator.getInstance().isValid(hostName)) {
+      // endpoint validation typically uses a hostname in the sniServer parameter that the handshake
+      // will compare against the subject alternative addresses in the server's certificate. Here
+      // we attempt to get a hostname instead of the proffered numeric address
+      try {
+        hostName = InetAddress.getByName(hostName).getCanonicalHostName();
+      } catch (UnknownHostException e) {
+        // ignore - we'll see what happens with endpoint validation using a numeric address...
+      }
+    }
+    serverNames.add(new SNIHostName(hostName));
+    modifiedParams.setServerNames(serverNames);
+    return true;
   }
 
   /**
@@ -735,11 +818,6 @@ public class SocketCreator extends TcpSocketCreatorImpl {
     if (!clientSocket) {
       engine.setNeedClientAuth(sslConfig.isRequireAuth());
     }
-
-    if (clientSocket) {
-      SSLParameters modifiedParams = checkAndEnableHostnameValidation(engine.getSSLParameters());
-      engine.setSSLParameters(modifiedParams);
-    }
     while (!socketChannel.finishConnect()) {
       try {
         Thread.sleep(50);
@@ -783,18 +861,21 @@ public class SocketCreator extends TcpSocketCreatorImpl {
     return nioSslEngine;
   }
 
-  private SSLParameters checkAndEnableHostnameValidation(SSLParameters sslParameters) {
+  /**
+   * @return true if the parameters have been modified by this method
+   */
+  private boolean checkAndEnableHostnameValidation(SSLParameters sslParameters) {
     if (sslConfig.doEndpointIdentification()) {
       sslParameters.setEndpointIdentificationAlgorithm("HTTPS");
-    } else {
-      if (!hostnameValidationDisabledLogShown) {
-        logger.info("Your SSL configuration disables hostname validation. "
-            + "ssl-endpoint-identification-enabled should be set to true when SSL is enabled. "
-            + "Please refer to the Apache GEODE SSL Documentation for SSL Property: ssl‑endpoint‑identification‑enabled");
-        hostnameValidationDisabledLogShown = true;
-      }
+      return true;
+    }
+    if (!hostnameValidationDisabledLogShown) {
+      logger.info("Your SSL configuration disables hostname validation. "
+          + "ssl-endpoint-identification-enabled should be set to true when SSL is enabled. "
+          + "Please refer to the Apache GEODE SSL Documentation for SSL Property: ssl‑endpoint‑identification‑enabled");
+      hostnameValidationDisabledLogShown = true;
     }
-    return sslParameters;
+    return false;
   }
 
   /**
@@ -875,22 +956,31 @@ public class SocketCreator extends TcpSocketCreatorImpl {
    * When a socket is accepted from a server socket, it should be passed to this method for SSL
    * configuration.
    */
-  private void configureClientSSLSocket(Socket socket, int timeout) throws IOException {
+  private void configureClientSSLSocket(Socket socket, String hostName, int timeout)
+      throws IOException {
     if (socket instanceof SSLSocket) {
       SSLSocket sslSocket = (SSLSocket) socket;
 
       sslSocket.setUseClientMode(true);
       sslSocket.setEnableSessionCreation(true);
 
-      SSLParameters modifiedParams =
-          checkAndEnableHostnameValidation(sslSocket.getSSLParameters());
+      SSLParameters parameters = sslSocket.getSSLParameters();
+      boolean updateSSLParameters =
+          checkAndEnableHostnameValidation(parameters);
+
+      if (setServerNames(parameters, hostName)) {
+        updateSSLParameters = true;
+      }
 
       SSLParameterExtension sslParameterExtension = this.sslConfig.getSSLParameterExtension();
       if (sslParameterExtension != null) {
-        modifiedParams =
-            sslParameterExtension.modifySSLClientSocketParameters(modifiedParams);
+        parameters =
+            sslParameterExtension.modifySSLClientSocketParameters(parameters);
+      }
+
+      if (updateSSLParameters) {
+        sslSocket.setSSLParameters(parameters);
       }
-      sslSocket.setSSLParameters(modifiedParams);
 
       String[] protocols = this.sslConfig.getProtocolsAsStringArray();
 
diff --git a/geode-core/src/main/java/org/apache/geode/internal/tcp/Connection.java b/geode-core/src/main/java/org/apache/geode/internal/tcp/Connection.java
index ee9892b..83465d8 100644
--- a/geode-core/src/main/java/org/apache/geode/internal/tcp/Connection.java
+++ b/geode-core/src/main/java/org/apache/geode/internal/tcp/Connection.java
@@ -1676,7 +1676,8 @@ public class Connection implements Runnable {
     if (getConduit().useSSL() && channel != null) {
       InetSocketAddress address = (InetSocketAddress) channel.getRemoteAddress();
       SSLEngine engine =
-          getConduit().getSocketCreator().createSSLEngine(address.getHostName(), address.getPort());
+          getConduit().getSocketCreator().createSSLEngine(address.getHostName(), address.getPort(),
+              clientSocket);
 
       int packetBufferSize = engine.getSession().getPacketBufferSize();
       if (inputBuffer == null || inputBuffer.capacity() < packetBufferSize) {
diff --git a/geode-core/src/test/java/org/apache/geode/internal/net/SocketCreatorJUnitTest.java b/geode-core/src/test/java/org/apache/geode/internal/net/SocketCreatorJUnitTest.java
index d37e043..83316a1 100644
--- a/geode-core/src/test/java/org/apache/geode/internal/net/SocketCreatorJUnitTest.java
+++ b/geode-core/src/test/java/org/apache/geode/internal/net/SocketCreatorJUnitTest.java
@@ -15,20 +15,27 @@
 package org.apache.geode.internal.net;
 
 import static org.apache.geode.test.util.ResourceUtils.createTempFileFromResource;
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.ArgumentMatchers.isA;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 import java.net.BindException;
 import java.net.InetAddress;
 import java.net.ServerSocket;
 import java.net.Socket;
 
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLParameters;
 import javax.net.ssl.SSLSocket;
 
 import org.junit.Test;
 import org.junit.experimental.categories.Category;
+import org.mockito.ArgumentCaptor;
 
 import org.apache.geode.internal.admin.SSLConfig;
 import org.apache.geode.test.junit.categories.MembershipTest;
@@ -98,6 +105,55 @@ public class SocketCreatorJUnitTest {
     }
   }
 
+  @Test
+  public void configureSSLEngine() {
+    SSLConfig config = new SSLConfig.Builder().setCiphers("someCipher").setEnabled(true)
+        .setProtocols("someProtocol").setRequireAuth(true).setKeystore("someKeystore.jks")
+        .setAlias("someAlias").setTruststore("someTruststore.jks")
+        .setEndpointIdentificationEnabled(true).build();
+
+    SSLContext context = mock(SSLContext.class);
+    SSLParameters parameters = mock(SSLParameters.class);
+
+    SocketCreator socketCreator = new SocketCreator(config, context);
+
+    SSLEngine engine = mock(SSLEngine.class);
+    when(engine.getSSLParameters()).thenReturn(parameters);
+
+    socketCreator.configureSSLEngine(engine, "somehost", 12345, true);
+
+    verify(engine).setUseClientMode(isA(Boolean.class));
+    verify(engine).setSSLParameters(parameters);
+    verify(engine, never()).setNeedClientAuth(isA(Boolean.class));
+
+    ArgumentCaptor<String[]> stringArrayCaptor = ArgumentCaptor.forClass(String[].class);
+    verify(engine).setEnabledProtocols(stringArrayCaptor.capture());
+    assertThat(stringArrayCaptor.getValue()).containsExactly("someProtocol");
+    verify(engine).setEnabledCipherSuites(stringArrayCaptor.capture());
+    assertThat(stringArrayCaptor.getValue()).containsExactly("someCipher");
+  }
+
+  @Test
+  public void configureSSLEngineUsingAny() {
+    SSLConfig config = new SSLConfig.Builder().setCiphers("any").setEnabled(true)
+        .setProtocols("any").setRequireAuth(true).setKeystore("someKeystore.jks")
+        .setAlias("someAlias").setTruststore("someTruststore.jks")
+        .setEndpointIdentificationEnabled(true).build();
+
+    SSLContext context = mock(SSLContext.class);
+    SSLParameters parameters = mock(SSLParameters.class);
+
+    SocketCreator socketCreator = new SocketCreator(config, context);
+
+    SSLEngine engine = mock(SSLEngine.class);
+    when(engine.getSSLParameters()).thenReturn(parameters);
+
+    socketCreator.configureSSLEngine(engine, "somehost", 12345, true);
+
+    verify(engine, never()).setEnabledCipherSuites(isA(String[].class));
+    verify(engine, never()).setEnabledProtocols(isA(String[].class));
+  }
+
   private String getSingleKeyKeystore() {
     return createTempFileFromResource(getClass(), "/ssl/trusted.keystore").getAbsolutePath();
   }

Reply via email to