This is an automated email from the ASF dual-hosted git repository. mridulm80 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2709426f0f6 [SPARK-45541][CORE] Add SSLFactory
2709426f0f6 is described below

commit 2709426f0f622a214c51954664458e7dd2ab3304
Author: Hasnain Lakhani <hasnain.lakh...@databricks.com>
AuthorDate: Sun Oct 22 03:21:19 2023 -0500

[SPARK-45541][CORE] Add SSLFactory

### What changes were proposed in this pull request?

As titled: add a factory that creates SSL engines, plus a corresponding builder for it. This will be used in a follow-up PR by the `TransportContext` and related files to add SSL support.

### Why are the changes needed?

We need a mechanism to initialize the appropriate SSL implementation with the configured settings (such as protocol, ciphers, etc.) for RPC SSL support.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests. This will be more thoroughly tested in a follow-up PR that adds call sites for it. It has been integration tested as part of https://github.com/apache/spark/pull/42685

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #43386 from hasnain-db/spark-tls-factory.
Authored-by: Hasnain Lakhani <hasnain.lakh...@databricks.com> Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com> --- .../spark/network/protocol/SslMessageEncoder.java | 2 +- .../network/ssl/ReloadingX509TrustManager.java | 5 +- .../org/apache/spark/network/ssl/SSLFactory.java | 470 +++++++++++++++++++++ 3 files changed, 474 insertions(+), 3 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/SslMessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/SslMessageEncoder.java index f43d0789ee6..87723c6613e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/SslMessageEncoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/SslMessageEncoder.java @@ -36,7 +36,7 @@ import org.slf4j.LoggerFactory; @ChannelHandler.Sharable public final class SslMessageEncoder extends MessageToMessageEncoder<Message> { - private final Logger logger = LoggerFactory.getLogger(SslMessageEncoder.class); + private static final Logger logger = LoggerFactory.getLogger(SslMessageEncoder.class); private SslMessageEncoder() {} diff --git a/common/network-common/src/main/java/org/apache/spark/network/ssl/ReloadingX509TrustManager.java b/common/network-common/src/main/java/org/apache/spark/network/ssl/ReloadingX509TrustManager.java index 4c39a5d2a3d..87798bda2a0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/ssl/ReloadingX509TrustManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/ssl/ReloadingX509TrustManager.java @@ -46,7 +46,7 @@ import org.slf4j.LoggerFactory; public final class ReloadingX509TrustManager implements X509TrustManager, Runnable { - private final Logger logger = LoggerFactory.getLogger(ReloadingX509TrustManager.class); + private static final Logger logger = LoggerFactory.getLogger(ReloadingX509TrustManager.class); private final String type; private final File file; @@ -180,7 +180,8 @@ 
public final class ReloadingX509TrustManager canonicalPath = latestCanonicalFile.getPath(); lastLoaded = latestCanonicalFile.lastModified(); try (FileInputStream in = new FileInputStream(latestCanonicalFile)) { - ks.load(in, password.toCharArray()); + char[] passwordCharacters = password != null? password.toCharArray() : null; + ks.load(in, passwordCharacters); logger.debug("Loaded truststore '" + file + "'"); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/ssl/SSLFactory.java b/common/network-common/src/main/java/org/apache/spark/network/ssl/SSLFactory.java new file mode 100644 index 00000000000..fc03dba617f --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/ssl/SSLFactory.java @@ -0,0 +1,470 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.spark.network.ssl;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.security.GeneralSecurityException;
+import java.security.KeyStore;
+import java.security.KeyStoreException;
+import java.security.NoSuchAlgorithmException;
+import java.security.UnrecoverableKeyException;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLException;
+import javax.net.ssl.TrustManager;
+import javax.net.ssl.TrustManagerFactory;
+import javax.net.ssl.X509TrustManager;
+
+import com.google.common.io.Files;
+
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.handler.ssl.OpenSsl;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.SslContextBuilder;
+import io.netty.handler.ssl.SslProvider;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * Creates {@link SSLEngine}s from a configuration supplied through {@link Builder}.
+ * <p>
+ * Two configurations are supported:
+ * <ul>
+ *   <li>a PEM certificate chain plus a PEM private key, backed by one Netty {@link SslContext}
+ *   for clients and one for servers; or</li>
+ *   <li>key store / trust store files of the JVM's default key store type, backed by a single
+ *   JDK {@link SSLContext}.</li>
+ * </ul>
+ * If both a certificate chain and a private key are configured, the PEM configuration is used.
+ * Call {@link #destroy()} when the factory is no longer needed so that any reloading trust
+ * manager it created is stopped.
+ */
+public class SSLFactory {
+  private static final Logger logger = LoggerFactory.getLogger(SSLFactory.class);
+
+  /**
+   * For a configuration specifying keystore/truststore files
+   */
+  private SSLContext jdkSslContext;
+
+  /**
+   * For a configuration specifying a PEM cert chain, and a PEM private key
+   */
+  private SslContext nettyClientSslContext;
+  private SslContext nettyServerSslContext;
+
+  private KeyManager[] keyManagers;
+  private TrustManager[] trustManagers;
+  private String requestedProtocol;
+  private String[] requestedCiphers;
+
+  private SSLFactory(final Builder b) {
+    this.requestedProtocol = b.requestedProtocol;
+    this.requestedCiphers = b.requestedCiphers;
+    try {
+      // A PEM cert chain + private key pair takes precedence over key store settings.
+      if (b.certChain != null && b.privateKey != null) {
+        initNettySslContexts(b);
+      } else {
+        initJdkSslContext(b);
+      }
+    } catch (Exception e) {
+      throw new RuntimeException("SSLFactory creation failed", e);
+    }
+  }
+
+  /**
+   * Initializes the key managers, trust managers and the JDK {@link SSLContext} from the
+   * key store / trust store settings of the builder.
+   */
+  private void initJdkSslContext(final Builder b)
+      throws IOException, GeneralSecurityException {
+    this.keyManagers = keyManagers(b.keyStore, b.keyStorePassword);
+    this.trustManagers = trustStoreManagers(
+      b.trustStore, b.trustStorePassword,
+      b.trustStoreReloadingEnabled, b.trustStoreReloadIntervalMs
+    );
+    this.jdkSslContext = createSSLContext(requestedProtocol, keyManagers, trustManagers);
+  }
+
+  /**
+   * Initializes the Netty client and server {@link SslContext}s from the PEM settings of the
+   * builder. The client context trusts the configured certificate chain.
+   */
+  private void initNettySslContexts(final Builder b)
+      throws SSLException {
+    // Resolve the provider once so that a missing OpenSSL implementation is only warned about
+    // once rather than once per context.
+    SslProvider provider = getSslProvider(b);
+    nettyClientSslContext = SslContextBuilder
+      .forClient()
+      .sslProvider(provider)
+      .trustManager(b.certChain)
+      .build();
+
+    nettyServerSslContext = SslContextBuilder
+      .forServer(b.certChain, b.privateKey, b.keyPassword)
+      .sslProvider(provider)
+      .build();
+  }
+
+  /**
+   * If OpenSSL is requested, this will check if an implementation is available on the local host.
+   * If an implementation is not available it will fall back to the JDK {@link SslProvider}.
+   *
+   * @param b The builder whose {@code openSslEnabled} flag is consulted
+   * @return {@link SslProvider#OPENSSL} if requested and available, otherwise
+   *         {@link SslProvider#JDK}
+   */
+  private SslProvider getSslProvider(Builder b) {
+    if (b.openSslEnabled) {
+      if (OpenSsl.isAvailable()) {
+        return SslProvider.OPENSSL;
+      } else {
+        logger.warn("OpenSSL Provider requested but it is not available, using JDK SSL Provider");
+      }
+    }
+    return SslProvider.JDK;
+  }
+
+  /**
+   * Stops every {@link ReloadingX509TrustManager} created by this factory and clears all
+   * references to SSL state. If the calling thread is interrupted while a trust manager is being
+   * stopped, the remaining managers are still processed and the thread's interrupt status is
+   * restored before returning.
+   */
+  public void destroy() {
+    boolean interrupted = false;
+    if (trustManagers != null) {
+      for (TrustManager manager : trustManagers) {
+        if (manager instanceof ReloadingX509TrustManager) {
+          try {
+            ((ReloadingX509TrustManager) manager).destroy();
+          } catch (InterruptedException ex) {
+            logger.info("Interrupted while destroying trust manager: " + ex.toString(), ex);
+            interrupted = true;
+          }
+        }
+      }
+      trustManagers = null;
+    }
+
+    keyManagers = null;
+    jdkSslContext = null;
+    nettyClientSslContext = null;
+    nettyServerSslContext = null;
+    requestedProtocol = null;
+    requestedCiphers = null;
+
+    if (interrupted) {
+      // Re-assert the interrupt that was caught above so callers can still observe it.
+      Thread.currentThread().interrupt();
+    }
+  }
+
+  /**
+   * Builder class to construct instances of {@link SSLFactory} with specific options
+   */
+  public static class Builder {
+    private String requestedProtocol;
+    private String[] requestedCiphers;
+    private File keyStore;
+    private String keyStorePassword;
+    private File privateKey;
+    private String keyPassword;
+    private File certChain;
+    private File trustStore;
+    private String trustStorePassword;
+    private boolean trustStoreReloadingEnabled;
+    private int trustStoreReloadIntervalMs;
+    private boolean openSslEnabled;
+
+    /**
+     * Sets the requested protocol, i.e., "TLSv1.2", "TLSv1.1", etc
+     *
+     * @param requestedProtocol The requested protocol; {@code null} is replaced by "TLSv1.2"
+     * @return The builder object
+     */
+    public Builder requestedProtocol(String requestedProtocol) {
+      this.requestedProtocol = requestedProtocol == null ? "TLSv1.2" : requestedProtocol;
+      return this;
+    }
+
+    /**
+     * Sets the requested cipher suites, i.e., "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", etc
+     *
+     * @param requestedCiphers The requested ciphers; a copy of the array is kept
+     * @return The builder object
+     */
+    public Builder requestedCiphers(String[] requestedCiphers) {
+      // Defensive copy so later mutation of the caller's array cannot affect the factory.
+      this.requestedCiphers = requestedCiphers == null ? null : requestedCiphers.clone();
+      return this;
+    }
+
+    /**
+     * Sets the Keystore and Keystore password
+     *
+     * @param keyStore The key store file to use
+     * @param keyStorePassword The password for the key store
+     * @return The builder object
+     */
+    public Builder keyStore(File keyStore, String keyStorePassword) {
+      this.keyStore = keyStore;
+      this.keyStorePassword = keyStorePassword;
+      return this;
+    }
+
+    /**
+     * Sets a PKCS#8 private key file in PEM format
+     *
+     * @param privateKey The private key file to use
+     * @return The builder object
+     */
+    public Builder privateKey(File privateKey) {
+      this.privateKey = privateKey;
+      return this;
+    }
+
+    /**
+     * Sets the Key password
+     *
+     * @param keyPassword The password for the private key
+     * @return The builder object
+     */
+    public Builder keyPassword(String keyPassword) {
+      this.keyPassword = keyPassword;
+      return this;
+    }
+
+    /**
+     * Sets a X.509 certificate chain file in PEM format
+     *
+     * @param certChain The certificate chain file to use
+     * @return The builder object
+     */
+    public Builder certChain(File certChain) {
+      this.certChain = certChain;
+      return this;
+    }
+
+    /**
+     * @param enabled Whether to use the OpenSSL implementation
+     * @return The builder object
+     */
+    public Builder openSslEnabled(boolean enabled) {
+      this.openSslEnabled = enabled;
+      return this;
+    }
+
+    /**
+     * Sets the trust-store, trust-store password, whether to use a Reloading TrustStore,
+     * and the trust-store reload interval, if enabled
+     *
+     * @param trustStore The trust store file to use
+     * @param trustStorePassword The password for the trust store
+     * @param trustStoreReloadingEnabled Whether trust store reloading is enabled
+     * @param trustStoreReloadIntervalMs The interval at which to reload the trust store file
+     * @return The builder object
+     */
+    public Builder trustStore(
+        File trustStore, String trustStorePassword,
+        boolean trustStoreReloadingEnabled, int trustStoreReloadIntervalMs) {
+      this.trustStore = trustStore;
+      this.trustStorePassword = trustStorePassword;
+      this.trustStoreReloadingEnabled = trustStoreReloadingEnabled;
+      this.trustStoreReloadIntervalMs = trustStoreReloadIntervalMs;
+      return this;
+    }
+
+    /**
+     * Builds our {@link SSLFactory}
+     *
+     * @return The built {@link SSLFactory}
+     */
+    public SSLFactory build() {
+      return new SSLFactory(this);
+    }
+  }
+
+  /**
+   * Returns an initialized {@link SSLContext}
+   *
+   * @param requestedProtocol The requested protocol to use; if {@code null} or empty, a generic
+   *                          "TLS" context is created and each engine's protocols are narrowed
+   *                          later by {@code enabledProtocols}
+   * @param keyManagers The list of key managers to use
+   * @param trustManagers The list of trust managers to use
+   * @return The built {@link SSLContext}
+   * @throws GeneralSecurityException if the protocol is unsupported or initialization fails
+   */
+  private static SSLContext createSSLContext(
+      String requestedProtocol,
+      KeyManager[] keyManagers,
+      TrustManager[] trustManagers) throws GeneralSecurityException {
+    // Builder.requestedProtocol(...) is optional, so the protocol may be unset here;
+    // SSLContext.getInstance rejects null and empty names.
+    String contextProtocol =
+      (requestedProtocol == null || requestedProtocol.isEmpty()) ? "TLS" : requestedProtocol;
+    SSLContext sslContext = SSLContext.getInstance(contextProtocol);
+    sslContext.init(keyManagers, trustManagers, null);
+    return sslContext;
+  }
+
+  /**
+   * Creates a new {@link SSLEngine}.
+   * Note that currently client auth is not supported
+   *
+   * @param isClient Whether the engine is used in a client context
+   * @param allocator The {@link ByteBufAllocator} to use (only used by the Netty contexts)
+   * @return A valid {@link SSLEngine}.
+   */
+  public SSLEngine createSSLEngine(boolean isClient, ByteBufAllocator allocator) {
+    SSLEngine engine = createEngine(isClient, allocator);
+    engine.setUseClientMode(isClient);
+    engine.setNeedClientAuth(false);
+    engine.setEnabledProtocols(enabledProtocols(engine, requestedProtocol));
+    engine.setEnabledCipherSuites(enabledCipherSuites(engine, requestedCiphers));
+    return engine;
+  }
+
+  /**
+   * Creates a raw engine from the Netty context for the given side when one was configured,
+   * otherwise from the JDK context.
+   */
+  private SSLEngine createEngine(boolean isClient, ByteBufAllocator allocator) {
+    SslContext nettyContext = isClient ? nettyClientSslContext : nettyServerSslContext;
+    if (nettyContext != null) {
+      return nettyContext.newEngine(allocator);
+    }
+    return jdkSslContext.createSSLEngine();
+  }
+
+  /**
+   * Returns trust managers that accept every certificate without any validation. Used only
+   * when no trust store file is available.
+   */
+  private static TrustManager[] credulousTrustStoreManagers() {
+    return new TrustManager[]{new X509TrustManager() {
+      @Override
+      public void checkClientTrusted(X509Certificate[] x509Certificates, String s)
+          throws CertificateException {
+      }
+
+      @Override
+      public void checkServerTrusted(X509Certificate[] x509Certificates, String s)
+          throws CertificateException {
+      }
+
+      @Override
+      public X509Certificate[] getAcceptedIssuers() {
+        // The X509TrustManager contract requires a non-null (possibly empty) array.
+        return new X509Certificate[0];
+      }
+    }};
+  }
+
+  /**
+   * Builds the trust managers for the JDK context. If the trust store is unset or missing, a
+   * trust-all manager is returned. Otherwise the result is a reloading manager (if enabled) or
+   * the platform's default trust managers over the loaded store.
+   */
+  private static TrustManager[] trustStoreManagers(
+      File trustStore, String trustStorePassword,
+      boolean trustStoreReloadingEnabled, int trustStoreReloadIntervalMs)
+      throws IOException, GeneralSecurityException {
+    if (trustStore == null || !trustStore.exists()) {
+      if (trustStore != null) {
+        // A trust store was configured but cannot be found: make the insecure fallback visible.
+        logger.warn("Trust store file " + trustStore + " does not exist; falling back to a " +
+          "trust manager that accepts all certificates");
+      }
+      return credulousTrustStoreManagers();
+    } else {
+      if (trustStoreReloadingEnabled) {
+        ReloadingX509TrustManager reloading = new ReloadingX509TrustManager(
+          KeyStore.getDefaultType(), trustStore, trustStorePassword, trustStoreReloadIntervalMs);
+        reloading.init();
+        return new TrustManager[]{reloading};
+      } else {
+        return defaultTrustManagers(trustStore, trustStorePassword);
+      }
+    }
+  }
+
+  /**
+   * Loads the trust store with the JVM's default key store type and returns the trust managers
+   * of the default {@link TrustManagerFactory} algorithm initialized with it.
+   */
+  private static TrustManager[] defaultTrustManagers(File trustStore, String trustStorePassword)
+      throws IOException, KeyStoreException, CertificateException, NoSuchAlgorithmException {
+    try (InputStream input = Files.asByteSource(trustStore).openStream()) {
+      KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
+      char[] passwordCharacters = trustStorePassword != null ?
+        trustStorePassword.toCharArray() : null;
+      ks.load(input, passwordCharacters);
+      TrustManagerFactory tmf = TrustManagerFactory.getInstance(
+        TrustManagerFactory.getDefaultAlgorithm());
+      tmf.init(ks);
+      return tmf.getTrustManagers();
+    }
+  }
+
+  /**
+   * Returns the key managers of the default {@link KeyManagerFactory} algorithm, initialized
+   * with the key store; the key store password is also used as the key password.
+   */
+  private static KeyManager[] keyManagers(File keyStore, String keyStorePassword)
+      throws NoSuchAlgorithmException, CertificateException,
+          KeyStoreException, IOException, UnrecoverableKeyException {
+    KeyManagerFactory factory = KeyManagerFactory.getInstance(
+      KeyManagerFactory.getDefaultAlgorithm());
+    char[] passwordCharacters = keyStorePassword != null ? keyStorePassword.toCharArray() : null;
+    factory.init(loadKeyStore(keyStore, passwordCharacters), passwordCharacters);
+    return factory.getKeyManagers();
+  }
+
+  /**
+   * Loads the key store file with the JVM's default key store type.
+   *
+   * @throws KeyStoreException if no key store file was configured
+   */
+  private static KeyStore loadKeyStore(File keyStore, char[] keyStorePassword)
+      throws KeyStoreException, IOException, CertificateException, NoSuchAlgorithmException {
+    if (keyStore == null) {
+      throw new KeyStoreException(
+        "keyStore cannot be null. Please configure spark.ssl.rpc.keyStore");
+    }
+
+    KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
+    FileInputStream fin = new FileInputStream(keyStore);
+    try {
+      ks.load(fin, keyStorePassword);
+      return ks;
+    } finally {
+      JavaUtils.closeQuietly(fin);
+    }
+  }
+
+  /**
+   * Returns the requested protocol (or TLSv1.3/TLSv1.2 when none was requested) filtered to what
+   * the engine supports, falling back to all supported protocols if nothing matches.
+   */
+  private static String[] enabledProtocols(SSLEngine engine, String requestedProtocol) {
+    String[] supportedProtocols = engine.getSupportedProtocols();
+    String[] defaultProtocols = {"TLSv1.3", "TLSv1.2"};
+    String[] enabledProtocols =
+      ((requestedProtocol == null || requestedProtocol.isEmpty()) ?
+        defaultProtocols : new String[]{requestedProtocol});
+
+    List<String> protocols = addIfSupported(supportedProtocols, enabledProtocols);
+    if (!protocols.isEmpty()) {
+      return protocols.toArray(new String[0]);
+    } else {
+      return supportedProtocols;
+    }
+  }
+
+  /**
+   * Returns the requested ciphers (or a built-in modern list when none were requested) filtered
+   * to the supported ones, falling back to {@code defaultCiphers} if nothing matches.
+   */
+  private static String[] enabledCipherSuites(
+      String[] supportedCiphers, String[] defaultCiphers, String[] requestedCiphers) {
+    String[] baseCiphers = new String[]{
+      // We take ciphers from the mozilla modern list first (for TLS 1.3):
+      // https://wiki.mozilla.org/Security/Server_Side_TLS
+      "TLS_CHACHA20_POLY1305_SHA256",
+      "TLS_AES_128_GCM_SHA256",
+      "TLS_AES_256_GCM_SHA384",
+      // Next we have the TLS1.2 ciphers for intermediate compatibility with peers that cannot
+      // negotiate TLS1.3 (for example, older JDK8 builds)
+      "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
+      "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
+      "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
+      "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
+      "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
+      "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
+      "TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
+      "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384"
+    };
+    String[] enabledCiphers =
+      ((requestedCiphers == null || requestedCiphers.length == 0) ? baseCiphers : requestedCiphers);
+
+    List<String> ciphers = addIfSupported(supportedCiphers, enabledCiphers);
+    if (!ciphers.isEmpty()) {
+      return ciphers.toArray(new String[0]);
+    } else {
+      // Use the default from JDK as fallback.
+      return defaultCiphers;
+    }
+  }
+
+  private static String[] enabledCipherSuites(SSLEngine engine, String[] requestedCiphers) {
+    return enabledCipherSuites(
+      engine.getSupportedCipherSuites(), engine.getEnabledCipherSuites(), requestedCiphers);
+  }
+
+  /**
+   * Returns the entries of {@code names}, in their original order, that also appear in
+   * {@code supported}.
+   */
+  private static List<String> addIfSupported(String[] supported, String... names) {
+    List<String> enabled = new ArrayList<>();
+    Set<String> supportedSet = new HashSet<>(Arrays.asList(supported));
+    for (String n : names) {
+      if (supportedSet.contains(n)) {
+        enabled.add(n);
+      }
+    }
+    return enabled;
+  }
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org