This is an automated email from the ASF dual-hosted git repository.

ycai pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git


The following commit(s) were added to refs/heads/trunk by this push:
     new ca75ffe4d0 Mixed mode support for internode authentication during TLS 
upgrades
ca75ffe4d0 is described below

commit ca75ffe4d09a3e7b26a56345c0bdacaa284eaab7
Author: Jyothsna Konisa <jkon...@apple.com>
AuthorDate: Fri Oct 7 10:03:16 2022 -0700

    Mixed mode support for internode authentication during TLS upgrades
    
    patch by Jyothsna Konisa; reviewed by Jon Meredith, Yifan Cai for 
CASSANDRA-17923
---
 CHANGES.txt                                        |   1 +
 .../cassandra/net/InternodeConnectionUtils.java    |  11 +-
 .../apache/cassandra/net/OutboundConnection.java   |  23 ++-
 .../cassandra/net/OutboundConnectionInitiator.java |  43 ++++-
 .../async/NettyStreamingConnectionFactory.java     |  45 +++--
 test/conf/cassandra_ssl_test.truststore            | Bin 992 -> 3240 bytes
 .../test/InternodeEncryptionEnforcementTest.java   |   8 +-
 .../org/apache/cassandra/net/HandshakeTest.java    | 185 ++++++++++++++++++++-
 8 files changed, 282 insertions(+), 34 deletions(-)

diff --git a/CHANGES.txt b/CHANGES.txt
index c418c48ba9..85fdae62e8 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 4.2
+ * Mixed mode support for internode authentication during TLS upgrades (CASSANDRA-17923)
  * Revert Mockito downgrade from CASSANDRA-17750 (CASSANDRA-17496)
  * Add --older-than and --older-than-timestamp options for nodetool 
clearsnapshots (CASSANDRA-16860)
  * Fix "open RT bound as its last item" exception (CASSANDRA-17810)
diff --git a/src/java/org/apache/cassandra/net/InternodeConnectionUtils.java 
b/src/java/org/apache/cassandra/net/InternodeConnectionUtils.java
index 39a087960b..fd3d1bd69e 100644
--- a/src/java/org/apache/cassandra/net/InternodeConnectionUtils.java
+++ b/src/java/org/apache/cassandra/net/InternodeConnectionUtils.java
@@ -18,6 +18,7 @@
 
 package org.apache.cassandra.net;
 
+import java.nio.channels.ClosedChannelException;
 import java.security.cert.Certificate;
 import javax.net.ssl.SSLPeerUnverifiedException;
 
@@ -33,7 +34,7 @@ import io.netty.handler.ssl.SslHandler;
 /**
  * Class that contains certificate utility methods.
  */
-class InternodeConnectionUtils
+public class InternodeConnectionUtils
 {
     public static String SSL_HANDLER_NAME = "ssl";
     public static String DISCARD_HANDLER_NAME = "discard";
@@ -59,6 +60,14 @@ class InternodeConnectionUtils
         return certificates;
     }
 
+    public static boolean isSSLError(final Throwable cause)
+    {
+        return (cause instanceof ClosedChannelException)
+               && cause.getCause() == null
+               && 
cause.getStackTrace()[0].getClassName().contains("SslHandler")
+               && 
cause.getStackTrace()[0].getMethodName().contains("channelInactive");
+    }
+
     /**
      * Discard handler releases the received data silently. when internode 
authentication fails, the channel is closed,
      * but the pending buffered data may still be fired through the pipeline. 
To avoid that, authentication handler is
diff --git a/src/java/org/apache/cassandra/net/OutboundConnection.java 
b/src/java/org/apache/cassandra/net/OutboundConnection.java
index 821521bfb9..2af6d3b01d 100644
--- a/src/java/org/apache/cassandra/net/OutboundConnection.java
+++ b/src/java/org/apache/cassandra/net/OutboundConnection.java
@@ -61,6 +61,7 @@ import 
org.apache.cassandra.utils.concurrent.UncheckedInterruptedException;
 import static java.lang.Math.max;
 import static java.lang.Math.min;
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.apache.cassandra.net.InternodeConnectionUtils.isSSLError;
 import static org.apache.cassandra.net.MessagingService.current_version;
 import static org.apache.cassandra.net.OutboundConnectionInitiator.*;
 import static 
org.apache.cassandra.net.OutboundConnections.LARGE_MESSAGE_THRESHOLD;
@@ -1100,8 +1101,9 @@ public class OutboundConnection
 
                 if (hasPending())
                 {
+                    boolean isSSLFailure = isSSLError(cause);
                     Promise<Result<MessagingSuccess>> result = 
AsyncPromise.withExecutor(eventLoop);
-                    state = new Connecting(state.disconnected(), result, 
eventLoop.schedule(() -> attempt(result), max(100, retryRateMillis), 
MILLISECONDS));
+                    state = new Connecting(state.disconnected(), result, 
eventLoop.schedule(() -> attempt(result, isSSLFailure), max(100, 
retryRateMillis), MILLISECONDS));
                     retryRateMillis = min(1000, retryRateMillis * 2);
                 }
                 else
@@ -1189,7 +1191,7 @@ public class OutboundConnection
              *
              * Note: this should only be invoked on the event loop.
              */
-            private void attempt(Promise<Result<MessagingSuccess>> result)
+            private void attempt(Promise<Result<MessagingSuccess>> result, 
boolean sslFallbackEnabled)
             {
                 ++connectionAttempts;
 
@@ -1216,7 +1218,20 @@ public class OutboundConnection
                 // ensure we connect to the correct SSL port
                 settings = 
settings.withLegacyPortIfNecessary(messagingVersion);
 
-                initiateMessaging(eventLoop, type, settings, messagingVersion, 
result)
+                // In mixed-mode operation, some nodes may be configured to use SSL for internode connections while
+                // others are not. A node configured in optional SSL mode must handle both SSL and non-SSL internode
+                // connections; on the inbound side this is handled by an optional SSL handler.
+                // On the outbound side, when this node is in optional SSL mode and the previous attempt failed with
+                // an SSL-related error, we fall back to other SSL strategies so that we can still talk to older nodes
+                // in the cluster that use a different SSL configuration (for example, non-SSL).
+                SslFallbackConnectionType[] fallBackSslFallbackConnectionTypes 
= SslFallbackConnectionType.values();
+                int index = sslFallbackEnabled && settings.withEncryption() && 
settings.encryption.getOptional() ?
+                            (int) (connectionAttempts - 1) % 
fallBackSslFallbackConnectionTypes.length : 0;
+                if (fallBackSslFallbackConnectionTypes[index] != 
SslFallbackConnectionType.SERVER_CONFIG)
+                {
+                    logger.info("ConnectionId {} is falling back to {} 
reconnect strategy for retry", id(), fallBackSslFallbackConnectionTypes[index]);
+                }
+                initiateMessaging(eventLoop, type, 
fallBackSslFallbackConnectionTypes[index], settings, messagingVersion, result)
                 .addListener(future -> {
                     if (future.isCancelled())
                         return;
@@ -1231,7 +1246,7 @@ public class OutboundConnection
             {
                 Promise<Result<MessagingSuccess>> result = 
AsyncPromise.withExecutor(eventLoop);
                 state = new Connecting(state.disconnected(), result);
-                attempt(result);
+                attempt(result, false);
                 return result;
             }
         }
diff --git a/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java 
b/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java
index 7e38dd8812..f8df49b598 100644
--- a/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java
+++ b/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java
@@ -94,15 +94,17 @@ public class OutboundConnectionInitiator<SuccessType 
extends OutboundConnectionI
     private static final Logger logger = 
LoggerFactory.getLogger(OutboundConnectionInitiator.class);
 
     private final ConnectionType type;
+    private final SslFallbackConnectionType sslConnectionType;
     private final OutboundConnectionSettings settings;
     private final int requestMessagingVersion; // for pre40 nodes
     private final Promise<Result<SuccessType>> resultPromise;
     private boolean isClosed;
 
-    private OutboundConnectionInitiator(ConnectionType type, 
OutboundConnectionSettings settings,
+    private OutboundConnectionInitiator(ConnectionType type, 
SslFallbackConnectionType sslConnectionType, OutboundConnectionSettings 
settings,
                                         int requestMessagingVersion, 
Promise<Result<SuccessType>> resultPromise)
     {
         this.type = type;
+        this.sslConnectionType = sslConnectionType;
         this.requestMessagingVersion = requestMessagingVersion;
         this.settings = settings;
         this.resultPromise = resultPromise;
@@ -115,9 +117,10 @@ public class OutboundConnectionInitiator<SuccessType 
extends OutboundConnectionI
      *
      * The returned {@code Future} is guaranteed to be completed on the 
supplied eventLoop.
      */
-    public static Future<Result<StreamingSuccess>> initiateStreaming(EventLoop 
eventLoop, OutboundConnectionSettings settings, int requestMessagingVersion)
+    public static Future<Result<StreamingSuccess>> initiateStreaming(EventLoop 
eventLoop, OutboundConnectionSettings settings,
+                                                                     
SslFallbackConnectionType sslConnectionType, int requestMessagingVersion)
     {
-        return new OutboundConnectionInitiator<StreamingSuccess>(STREAMING, 
settings, requestMessagingVersion, AsyncPromise.withExecutor(eventLoop))
+        return new OutboundConnectionInitiator<StreamingSuccess>(STREAMING, 
sslConnectionType, settings, requestMessagingVersion, 
AsyncPromise.withExecutor(eventLoop))
                .initiate(eventLoop);
     }
 
@@ -128,9 +131,10 @@ public class OutboundConnectionInitiator<SuccessType 
extends OutboundConnectionI
      *
      * The returned {@code Future} is guaranteed to be completed on the 
supplied eventLoop.
      */
-    static Future<Result<MessagingSuccess>> initiateMessaging(EventLoop 
eventLoop, ConnectionType type, OutboundConnectionSettings settings, int 
requestMessagingVersion, Promise<Result<MessagingSuccess>> result)
+    static Future<Result<MessagingSuccess>> initiateMessaging(EventLoop 
eventLoop, ConnectionType type, SslFallbackConnectionType sslConnectionType,
+                                                              
OutboundConnectionSettings settings, int requestMessagingVersion, 
Promise<Result<MessagingSuccess>> result)
     {
-        return new OutboundConnectionInitiator<>(type, settings, 
requestMessagingVersion, result)
+        return new OutboundConnectionInitiator<>(type, sslConnectionType, 
settings, requestMessagingVersion, result)
                .initiate(eventLoop);
     }
 
@@ -202,6 +206,14 @@ public class OutboundConnectionInitiator<SuccessType 
extends OutboundConnectionI
         return bootstrap;
     }
 
+    public enum SslFallbackConnectionType
+    {
+        SERVER_CONFIG, // Original configuration of the server
+        MTLS,
+        SSL,
+        NO_SSL
+    }
+
     private class Initializer extends ChannelInitializer<SocketChannel>
     {
         public void initChannel(SocketChannel channel) throws Exception
@@ -209,11 +221,10 @@ public class OutboundConnectionInitiator<SuccessType 
extends OutboundConnectionI
             ChannelPipeline pipeline = channel.pipeline();
 
             // order of handlers: ssl -> server-authentication -> logger -> 
handshakeHandler
-            if (settings.withEncryption())
+            if ((sslConnectionType == SslFallbackConnectionType.SERVER_CONFIG 
&& settings.withEncryption())
+                || sslConnectionType == SslFallbackConnectionType.SSL || 
sslConnectionType == SslFallbackConnectionType.MTLS)
             {
-                // check if we should actually encrypt this connection
-                SslContext sslContext = 
SSLFactory.getOrCreateSslContext(settings.encryption, true,
-                                                                         
ISslContextFactory.SocketType.CLIENT);
+                SslContext sslContext = getSslContext(sslConnectionType);
                 // for some reason channel.remoteAddress() will return null
                 InetAddressAndPort address = settings.to;
                 InetSocketAddress peer = 
settings.encryption.require_endpoint_verification ? new 
InetSocketAddress(address.getAddress(), address.getPort()) : null;
@@ -229,6 +240,20 @@ public class OutboundConnectionInitiator<SuccessType 
extends OutboundConnectionI
             pipeline.addLast("handshake", new Handler());
         }
 
+        private SslContext getSslContext(SslFallbackConnectionType 
connectionType) throws IOException
+        {
+            boolean requireClientAuth = false;
+            if (connectionType == SslFallbackConnectionType.MTLS || 
connectionType == SslFallbackConnectionType.SSL)
+            {
+                requireClientAuth = true;
+            }
+            else if (connectionType == SslFallbackConnectionType.SERVER_CONFIG)
+            {
+                requireClientAuth = settings.withEncryption();
+            }
+            return SSLFactory.getOrCreateSslContext(settings.encryption, 
requireClientAuth, ISslContextFactory.SocketType.CLIENT);
+        }
+
     }
 
     /**
diff --git 
a/src/java/org/apache/cassandra/streaming/async/NettyStreamingConnectionFactory.java
 
b/src/java/org/apache/cassandra/streaming/async/NettyStreamingConnectionFactory.java
index 6a57e395e4..529b396367 100644
--- 
a/src/java/org/apache/cassandra/streaming/async/NettyStreamingConnectionFactory.java
+++ 
b/src/java/org/apache/cassandra/streaming/async/NettyStreamingConnectionFactory.java
@@ -20,6 +20,9 @@ package org.apache.cassandra.streaming.async;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
 
 import com.google.common.annotations.VisibleForTesting;
 
@@ -35,7 +38,10 @@ import org.apache.cassandra.net.OutboundConnectionSettings;
 import org.apache.cassandra.streaming.StreamingChannel;
 
 import static org.apache.cassandra.locator.InetAddressAndPort.getByAddress;
+import static org.apache.cassandra.net.InternodeConnectionUtils.isSSLError;
 import static 
org.apache.cassandra.net.OutboundConnectionInitiator.initiateStreaming;
+import static 
org.apache.cassandra.net.OutboundConnectionInitiator.SslFallbackConnectionType;
+import static 
org.apache.cassandra.net.OutboundConnectionInitiator.SslFallbackConnectionType.SERVER_CONFIG;
 
 public class NettyStreamingConnectionFactory implements 
StreamingChannel.Factory
 {
@@ -45,27 +51,38 @@ public class NettyStreamingConnectionFactory implements 
StreamingChannel.Factory
     public static NettyStreamingChannel connect(OutboundConnectionSettings 
template, int messagingVersion, StreamingChannel.Kind kind) throws IOException
     {
         EventLoop eventLoop = 
MessagingService.instance().socketFactory.outboundStreamingGroup().next();
+        OutboundConnectionSettings settings = 
template.withDefaults(ConnectionCategory.STREAMING);
+        List<SslFallbackConnectionType> sslFallbacks = 
settings.withEncryption() && settings.encryption.getOptional()
+                                                       ? 
Arrays.asList(SslFallbackConnectionType.values())
+                                                       : 
Collections.singletonList(SERVER_CONFIG);
 
-        int attempts = 0;
-        while (true)
+        Throwable cause = null;
+        for (final SslFallbackConnectionType sslFallbackConnectionType : 
sslFallbacks)
         {
-            Future<Result<StreamingSuccess>> result = 
initiateStreaming(eventLoop, 
template.withDefaults(ConnectionCategory.STREAMING), messagingVersion);
-            result.awaitUninterruptibly(); // initiate has its own timeout, so 
this is "guaranteed" to return relatively promptly
-            if (result.isSuccess())
+            for (int i = 0; i < MAX_CONNECT_ATTEMPTS; i++)
             {
-                Channel channel = result.getNow().success().channel;
-                NettyStreamingChannel streamingChannel = new 
NettyStreamingChannel(messagingVersion, channel, kind);
-                if (kind == StreamingChannel.Kind.CONTROL)
+                Future<Result<StreamingSuccess>> result = 
initiateStreaming(eventLoop, settings, sslFallbackConnectionType, 
messagingVersion);
+                result.awaitUninterruptibly(); // initiate has its own 
timeout, so this is "guaranteed" to return relatively promptly
+                if (result.isSuccess())
                 {
-                    ChannelPipeline pipeline = channel.pipeline();
-                    pipeline.addLast("stream", streamingChannel);
+                    Channel channel = result.getNow().success().channel;
+                    NettyStreamingChannel streamingChannel = new 
NettyStreamingChannel(messagingVersion, channel, kind);
+                    if (kind == StreamingChannel.Kind.CONTROL)
+                    {
+                        ChannelPipeline pipeline = channel.pipeline();
+                        pipeline.addLast("stream", streamingChannel);
+                    }
+                    return streamingChannel;
                 }
-                return streamingChannel;
+                cause = result.cause();
+            }
+            if (!isSSLError(cause))
+            {
+                // Fall back to the next SSL strategy only when the error is SSL-related; otherwise all retries are exhausted, so fail
+                break;
             }
-
-            if (++attempts == MAX_CONNECT_ATTEMPTS)
-                throw new IOException("failed to connect to " + template.to + 
" for streaming data", result.cause());
         }
+        throw new IOException("failed to connect to " + template.to + " for 
streaming data", cause);
     }
 
     @Override
diff --git a/test/conf/cassandra_ssl_test.truststore 
b/test/conf/cassandra_ssl_test.truststore
index 49cf3323e5..5ba9a9977c 100644
Binary files a/test/conf/cassandra_ssl_test.truststore and 
b/test/conf/cassandra_ssl_test.truststore differ
diff --git 
a/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java
 
b/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java
index 156a6b4b64..d13e2e4a0c 100644
--- 
a/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java
+++ 
b/test/distributed/org/apache/cassandra/distributed/test/InternodeEncryptionEnforcementTest.java
@@ -189,16 +189,18 @@ public final class InternodeEncryptionEnforcementTest 
extends TestBaseImpl
                 c.with(Feature.NETWORK);
                 c.with(Feature.NATIVE_PROTOCOL);
 
+                HashMap<String, Object> encryption = new HashMap<>();
+                encryption.put("optional", "false");
+                encryption.put("internode_encryption", "none");
                 if (c.num() == 1)
                 {
-                    HashMap<String, Object> encryption = new HashMap<>();
                     encryption.put("keystore", 
"test/conf/cassandra_ssl_test.keystore");
                     encryption.put("keystore_password", "cassandra");
                     encryption.put("truststore", 
"test/conf/cassandra_ssl_test.truststore");
                     encryption.put("truststore_password", "cassandra");
-                    encryption.put("internode_encryption", "dc");
-                    c.set("server_encryption_options", encryption);
+                    encryption.put("internode_encryption", "all");
                 }
+                c.set("server_encryption_options", encryption);
             })
             .withNodeIdTopology(ImmutableMap.of(1, 
NetworkTopology.dcAndRack("dc1", "r1a"),
                                                 2, 
NetworkTopology.dcAndRack("dc2", "r2a")));
diff --git a/test/unit/org/apache/cassandra/net/HandshakeTest.java 
b/test/unit/org/apache/cassandra/net/HandshakeTest.java
index 75ae1034c5..6a0f7d379a 100644
--- a/test/unit/org/apache/cassandra/net/HandshakeTest.java
+++ b/test/unit/org/apache/cassandra/net/HandshakeTest.java
@@ -19,10 +19,20 @@
 package org.apache.cassandra.net;
 
 import java.nio.channels.ClosedChannelException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 
+import com.google.common.net.InetAddresses;
+
+import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions;
+import org.apache.cassandra.config.ParameterizedClass;
+import org.apache.cassandra.gms.GossipDigestSyn;
+import org.apache.cassandra.security.DefaultSslContextFactory;
 import org.apache.cassandra.utils.concurrent.AsyncPromise;
 import org.junit.AfterClass;
 import org.junit.Assert;
@@ -42,11 +52,15 @@ import static 
org.apache.cassandra.net.MessagingService.current_version;
 import static org.apache.cassandra.net.MessagingService.minimum_version;
 import static org.apache.cassandra.net.ConnectionType.SMALL_MESSAGES;
 import static org.apache.cassandra.net.OutboundConnectionInitiator.*;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 
 // TODO: test failure due to exception, timeout, etc
 public class HandshakeTest
 {
     private static final SocketFactory factory = new SocketFactory();
+    static final InetAddressAndPort TO_ADDR = 
InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.2"),
 7012);
+    static final InetAddressAndPort FROM_ADDR = 
InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.1"),
 7012);
 
     @BeforeClass
     public static void startup()
@@ -80,6 +94,7 @@ public class HandshakeTest
             Future<Result<MessagingSuccess>> future =
             initiateMessaging(eventLoop,
                               SMALL_MESSAGES,
+                              SslFallbackConnectionType.SERVER_CONFIG,
                               new OutboundConnectionSettings(endpoint)
                                                     
.withAcceptVersions(acceptOutbound)
                                                     
.withDefaults(ConnectionCategory.MESSAGING),
@@ -92,6 +107,7 @@ public class HandshakeTest
         }
     }
 
+
     @Test
     public void testBothCurrentVersion() throws InterruptedException, 
ExecutionException
     {
@@ -172,7 +188,7 @@ public class HandshakeTest
         }
         catch (ExecutionException e)
         {
-            Assert.assertTrue(e.getCause() instanceof ClosedChannelException);
+            assertTrue(e.getCause() instanceof ClosedChannelException);
         }
     }
 
@@ -186,7 +202,7 @@ public class HandshakeTest
         }
         catch (ExecutionException e)
         {
-            Assert.assertTrue(e.getCause() instanceof ClosedChannelException);
+            assertTrue(e.getCause() instanceof ClosedChannelException);
         }
     }
 
@@ -207,7 +223,7 @@ public class HandshakeTest
         }
         catch (ExecutionException e)
         {
-            Assert.assertTrue(e.getCause() instanceof ClosedChannelException);
+            assertTrue(e.getCause() instanceof ClosedChannelException);
         }
     }
 
@@ -218,4 +234,167 @@ public class HandshakeTest
         Assert.assertEquals(Result.Outcome.SUCCESS, result.outcome);
         Assert.assertEquals(VERSION_30, result.success().messagingVersion);
     }
+
+    @Test
+    public void testOutboundConnectionfFallbackDuringUpgrades() throws 
ClosedChannelException, InterruptedException
+    {
+        // Upgrade from Non-SSL -> Optional SSL
+        // Outbound connection from Optional SSL(new node) -> Non-SSL (old 
node)
+        
testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.SSL, true, 
SslFallbackConnectionType.NO_SSL, false);
+
+        // Upgrade from Optional SSL -> Strict SSL
+        // Outbound connection from Strict SSL(new node) -> Optional SSL (old 
node)
+        
testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.SSL, false, 
SslFallbackConnectionType.SSL, true);
+
+        // Upgrade from Optional SSL -> Strict MTLS
+        // Outbound connection from Strict MTLS(new node) -> Optional SSL (old 
node)
+        
testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.MTLS, 
false, SslFallbackConnectionType.SSL, true);
+
+        // Upgrade from Strict SSL -> Optional MTLS
+        // Outbound connection from Optional MTLS(new node) -> Strict SSL (old 
node)
+        
testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.MTLS, true, 
SslFallbackConnectionType.SSL, false);
+
+        // Upgrade from Optional MTLS -> Strict MTLS
+        // Outbound connection from Strict MTLS (new node) -> Optional MTLS (old node)
+        
testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.MTLS, 
false, SslFallbackConnectionType.MTLS, true);
+    }
+
+    @Test
+    public void testOutboundConnectionfFallbackDuringDowngrades() throws 
ClosedChannelException, InterruptedException
+    {
+        // From Strict MTLS -> Optional MTLS
+        // Outbound connection from Optional MTLS (new node) -> Strict MTLS (old node)
+        
testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.MTLS, true, 
SslFallbackConnectionType.MTLS, false);
+
+        // From Optional MTLS -> Strict SSL
+        // Outbound connection from Strict SSL(new node) -> Optional MTLS (old 
node)
+        
testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.SSL, false, 
SslFallbackConnectionType.MTLS, true);
+
+        // From Strict MTLS -> Optional SSL
+        // Outbound connection from Optional SSL(new node) -> Strict MTLS (old 
node)
+        
testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.SSL, true, 
SslFallbackConnectionType.MTLS, false);
+
+        // From Strict SSL -> Optional SSL
+        // Outbound connection from Optional SSL(new node) -> Strict SSL (old 
node)
+        
testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.SSL, true, 
SslFallbackConnectionType.SSL, false);
+
+        // From Optional SSL -> Non-SSL
+        // Outbound connection from Non-SSL(new node) -> Optional SSL (old 
node)
+        
testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType.NO_SSL, 
false, SslFallbackConnectionType.SSL, true);
+    }
+
+    @Test
+    public void testOutboundConnectionDoesntFallbackWhenErrorIsNotSSLRelated() 
throws ClosedChannelException, InterruptedException
+    {
+        // Configure nodes in optional SSL mode.
+        // When optional mode is enabled, an SSL-related connection error triggers a fallback to another SSL strategy;
+        // any other error is retried with the same SSL strategy.
+        ServerEncryptionOptions serverEncryptionOptions = 
getServerEncryptionOptions(SslFallbackConnectionType.SSL, true);
+        InboundSockets inbound = getInboundSocket(serverEncryptionOptions);
+        try
+        {
+            InetAddressAndPort endpoint = inbound.sockets().stream().map(s -> 
s.settings.bindAddress).findFirst().get();
+
+            // Open the outbound connection before the server starts listening.
+            // Once inbound connections are opened, the connection should be accepted with the same SSL context, without fallback.
+            OutboundConnection outboundConnection = initiateOutbound(endpoint, 
SslFallbackConnectionType.SSL, true);
+
+            // Wait until the outbound connection has been attempted at least once per SslFallbackConnectionType (4 times)
+            while (outboundConnection.connectionAttempts() < 
SslFallbackConnectionType.values().length)
+            {
+                Thread.sleep(1000);
+            }
+            assertFalse(outboundConnection.isConnected());
+            inbound.open();
+            // As soon as the node accepts inbound connections, the connection must be established with the right SSL context
+            waitForConnection(outboundConnection);
+            assertTrue(outboundConnection.isConnected());
+            assertFalse(outboundConnection.hasPending());
+        }
+        finally
+        {
+            inbound.close().await(10L, TimeUnit.SECONDS);
+        }
+    }
+
+    private ServerEncryptionOptions 
getServerEncryptionOptions(SslFallbackConnectionType sslConnectionType, boolean 
optional)
+    {
+        ServerEncryptionOptions serverEncryptionOptions = new 
ServerEncryptionOptions().withOptional(optional)
+                                                                               
        .withKeyStore("test/conf/cassandra_ssl_test.keystore")
+                                                                               
        .withKeyStorePassword("cassandra")
+                                                                               
        .withOutboundKeystore("test/conf/cassandra_ssl_test_outbound.keystore")
+                                                                               
        .withOutboundKeystorePassword("cassandra")
+                                                                               
        .withTrustStore("test/conf/cassandra_ssl_test.truststore")
+                                                                               
        .withTrustStorePassword("cassandra")
+                                                                               
        .withSslContextFactory((new 
ParameterizedClass(DefaultSslContextFactory.class.getName(),
+                                                                               
                                                       new HashMap<>())));
+        if (sslConnectionType == SslFallbackConnectionType.MTLS)
+        {
+            serverEncryptionOptions = 
serverEncryptionOptions.withInternodeEncryption(ServerEncryptionOptions.InternodeEncryption.all)
+                                                             
.withRequireClientAuth(true);
+        }
+        else if (sslConnectionType == SslFallbackConnectionType.SSL)
+        {
+            serverEncryptionOptions = 
serverEncryptionOptions.withInternodeEncryption(ServerEncryptionOptions.InternodeEncryption.all)
+                                                             
.withRequireClientAuth(false);
+        }
+        return serverEncryptionOptions;
+    }
+
+    private InboundSockets getInboundSocket(ServerEncryptionOptions 
serverEncryptionOptions)
+    {
+        InboundConnectionSettings settings = new 
InboundConnectionSettings().withAcceptMessaging(new 
AcceptVersions(minimum_version, current_version))
+                                                                            
.withEncryption(serverEncryptionOptions)
+                                                                            
.withBindAddress(TO_ADDR);
+        List<InboundConnectionSettings> settingsList =  new ArrayList<>();
+        settingsList.add(settings);
+        return new InboundSockets(settingsList);
+    }
+
+    private OutboundConnection initiateOutbound(InetAddressAndPort endpoint, 
SslFallbackConnectionType connectionType, boolean optional) throws 
ClosedChannelException
+    {
+        final OutboundConnectionSettings settings = new 
OutboundConnectionSettings(endpoint)
+        .withAcceptVersions(new AcceptVersions(minimum_version, 
current_version))
+        .withDefaults(ConnectionCategory.MESSAGING)
+        .withEncryption(getServerEncryptionOptions(connectionType, optional))
+        .withFrom(FROM_ADDR);
+        OutboundConnections outboundConnections = 
OutboundConnections.tryRegister(new ConcurrentHashMap<>(), TO_ADDR, settings);
+        GossipDigestSyn syn = new GossipDigestSyn("cluster", "partitioner", 
new ArrayList<>(0));
+        Message<GossipDigestSyn> message = Message.out(Verb.GOSSIP_DIGEST_SYN, 
syn);
+        OutboundConnection outboundConnection = 
outboundConnections.connectionFor(message);
+        outboundConnection.enqueue(message);
+        outboundConnection.initiate();
+        return outboundConnection;
+    }
+
+    private void 
testOutboundFallbackOnSSLHandshakeFailure(SslFallbackConnectionType 
fromConnectionType, boolean fromOptional,
+                                                           
SslFallbackConnectionType toConnectionType, boolean toOptional) throws 
ClosedChannelException, InterruptedException
+    {
+        // Configure inbound connections with the target node's SSL type and optional mode
+        InboundSockets inbound = 
getInboundSocket(getServerEncryptionOptions(toConnectionType, toOptional));
+        try
+        {
+            InetAddressAndPort endpoint = inbound.sockets().stream().map(s -> 
s.settings.bindAddress).findFirst().get();
+            inbound.open();
+
+            // Open an outbound connection and wait until it is established
+            OutboundConnection outboundConnection = initiateOutbound(endpoint, 
fromConnectionType, fromOptional);
+            waitForConnection(outboundConnection);
+            assertTrue(outboundConnection.isConnected());
+            assertFalse(outboundConnection.hasPending());
+        }
+        finally
+        {
+            inbound.close().await(10L, TimeUnit.SECONDS);
+        }
+    }
+
+    private void waitForConnection(OutboundConnection outboundConnection) 
throws InterruptedException
+    {
+        long startTime = System.currentTimeMillis();
+        while (!outboundConnection.isConnected() && System.currentTimeMillis() 
- startTime < 60000)
+        {
+            Thread.sleep(1000);
+        }
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org
For additional commands, e-mail: commits-h...@cassandra.apache.org

Reply via email to