This is an automated email from the ASF dual-hosted git repository. frankgh pushed a commit to branch trunk in repository https://gitbox.apache.org/repos/asf/cassandra-analytics.git
The following commit(s) were added to refs/heads/trunk by this push: new 6901018 CASSANDRA-19526: Optionally enable TLS in the server and client for Analytics testing 6901018 is described below commit 690101840d4d8f9c656bb0ca114f6619af80e1cf Author: Francisco Guerrero <fran...@apache.org> AuthorDate: Mon Apr 8 14:33:50 2024 -0700 CASSANDRA-19526: Optionally enable TLS in the server and client for Analytics testing All integration tests today run without TLS, which is generally fine because they run locally. However, it is helpful to be able to start up the sidecar with TLS enabled in the integration test framework so that third-party tests could connect via secure connections for testing purposes. Co-authored-by: Doug Rohrer <droh...@apple.com> Co-authored-by: Francisco Guerrero <fran...@apache.org> Patch by Doug Rohrer, Francisco Guerrero; Reviewed by Yifan Cai for CASSANDRA-19526 --- .../spark/common/stats/JobStatsPublisher.java | 2 + .../build.gradle | 6 +- .../distributed/impl/CassandraCluster.java | 1 + .../cassandra/sidecar/testing/MtlsTestHelper.java | 146 ++++++++ .../testing/SharedClusterIntegrationTestBase.java | 54 ++- .../testing/utils/tls/CertificateBuilder.java | 236 +++++++++++++ .../testing/utils/tls/CertificateBundle.java | 112 ++++++ cassandra-analytics-integration-tests/build.gradle | 4 + .../cassandra/analytics/BlockedInstancesTest.java | 29 +- .../cassandra/analytics/DataGenerationUtils.java | 50 ++- .../cassandra/analytics/IntegrationTestJob.java | 379 --------------------- .../SharedClusterSparkIntegrationTestBase.java | 40 ++- .../analytics/SparkBulkWriterSimpleTest.java | 118 ++----- .../apache/cassandra/analytics/SparkTestUtils.java | 23 +- 14 files changed, 710 insertions(+), 490 deletions(-) diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/stats/JobStatsPublisher.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/stats/JobStatsPublisher.java index 28643e8..9027ce4 100644 --- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/stats/JobStatsPublisher.java +++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/stats/JobStatsPublisher.java @@ -30,6 +30,8 @@ public interface JobStatsPublisher { /** * Publish the job attributes to be persisted and summarized + * + * @param stats the stats to publish */ void publish(Map<String, String> stats); } diff --git a/cassandra-analytics-integration-framework/build.gradle b/cassandra-analytics-integration-framework/build.gradle index aeba617..63d66c7 100644 --- a/cassandra-analytics-integration-framework/build.gradle +++ b/cassandra-analytics-integration-framework/build.gradle @@ -75,7 +75,11 @@ dependencies { exclude group: 'junit', module: 'junit' } implementation("io.vertx:vertx-web-client:${project.vertxVersion}") - implementation group: 'com.fasterxml.jackson.core', name: 'jackson-annotations', version: '2.14.3' + implementation(group: 'com.fasterxml.jackson.core', name: 'jackson-annotations', version: '2.14.3') + + // Bouncycastle dependencies for test certificate provisioning + implementation(group: 'org.bouncycastle', name: 'bcprov-jdk18on', version: '1.78') + implementation(group: 'org.bouncycastle', name: 'bcpkix-jdk18on', version: '1.78') testImplementation(platform("org.junit:junit-bom:${project.junitVersion}")) testImplementation('org.junit.jupiter:junit-jupiter') diff --git a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/distributed/impl/CassandraCluster.java b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/distributed/impl/CassandraCluster.java index 4d20c62..f5a5abd 100644 --- a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/distributed/impl/CassandraCluster.java +++ b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/distributed/impl/CassandraCluster.java @@ -68,6 +68,7 @@ public class CassandraCluster<I extends IInstance> implements IClusterExtension< // java.lang.IllegalStateException: Can't load <CLASS>. Instance class loader is already closed. return className.equals("org.apache.cassandra.utils.concurrent.Ref$OnLeak") || className.startsWith("org.apache.cassandra.metrics.RestorableMeter") + || className.equals("org.apache.logging.slf4j.EventDataConverter") || (className.startsWith("org.apache.cassandra.analytics.") && className.contains("BBHelper")); }; diff --git a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/MtlsTestHelper.java b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/MtlsTestHelper.java new file mode 100644 index 0000000..55f200f --- /dev/null +++ b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/MtlsTestHelper.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.sidecar.testing; + +import java.nio.file.Path; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.testing.utils.tls.CertificateBuilder; +import org.apache.cassandra.testing.utils.tls.CertificateBundle; + +/** + * A class that encapsulates testing with Mutual TLS. + */ +public class MtlsTestHelper +{ + private static final Logger LOGGER = LoggerFactory.getLogger(MtlsTestHelper.class); + public static final char[] EMPTY_PASSWORD = new char[0]; + public static final String EMPTY_PASSWORD_STRING = ""; + /** + * A system property that can enable / disable testing with Mutual TLS + */ + public static final String CASSANDRA_INTEGRATION_TEST_ENABLE_MTLS = "cassandra.integration.sidecar.test.enable_mtls"; + private final boolean enableMtlsForTesting; + CertificateBundle certificateAuthority; + Path truststorePath; + Path serverKeyStorePath; + Path clientKeyStorePath; + + public MtlsTestHelper(Path secretsPath) throws Exception + { + this(secretsPath, System.getProperty(CASSANDRA_INTEGRATION_TEST_ENABLE_MTLS, "false").equals("true")); + } + + public MtlsTestHelper(Path secretsPath, boolean enableMtlsForTesting) throws Exception + { + this.enableMtlsForTesting = enableMtlsForTesting; + maybeInitializeSecrets(Objects.requireNonNull(secretsPath, "secretsPath cannot be null")); + } + + void maybeInitializeSecrets(Path secretsPath) throws Exception + { + if (!enableMtlsForTesting) + { + return; + } + + certificateAuthority = + new CertificateBuilder().subject("CN=Apache Cassandra Root CA, OU=Certification Authority, O=Unknown, C=Unknown") + .alias("fakerootca") + .isCertificateAuthority(true) + .buildSelfSigned(); + truststorePath = certificateAuthority.toTempKeyStorePath(secretsPath, EMPTY_PASSWORD, EMPTY_PASSWORD); + + CertificateBuilder serverKeyStoreBuilder = + new CertificateBuilder().subject("CN=Apache Cassandra, OU=mtls_test, O=Unknown, L=Unknown, ST=Unknown, C=Unknown") + .addSanDnsName("localhost"); + // Add SANs for every potential hostname Sidecar will serve + for (int i = 1; i <= 20; i++) + { + serverKeyStoreBuilder.addSanDnsName("localhost" + i); + } + + CertificateBundle serverKeyStore = serverKeyStoreBuilder.buildIssuedBy(certificateAuthority); + serverKeyStorePath = serverKeyStore.toTempKeyStorePath(secretsPath, EMPTY_PASSWORD, EMPTY_PASSWORD); + CertificateBundle clientKeyStore = new CertificateBuilder().subject("CN=Apache Cassandra, OU=mtls_test, O=Unknown, L=Unknown, ST=Unknown, C=Unknown") + .addSanDnsName("localhost") + .buildIssuedBy(certificateAuthority); + clientKeyStorePath = clientKeyStore.toTempKeyStorePath(secretsPath, EMPTY_PASSWORD, EMPTY_PASSWORD); + } + + public boolean isEnabled() + { + return enableMtlsForTesting; + } + + public String trustStorePath() + { + return truststorePath.toString(); + } + + public String trustStorePassword() + { + return EMPTY_PASSWORD_STRING; + } + + public String trustStoreType() + { + return "PKCS12"; + } + + public String serverKeyStorePath() + { + return serverKeyStorePath.toString(); + } + + public String serverKeyStorePassword() + { + return EMPTY_PASSWORD_STRING; + } + + public String serverKeyStoreType() + { + return "PKCS12"; + } + + public Map<String, String> mtlOptionMap() + { + if (!isEnabled()) + { + return Collections.emptyMap(); + } + + LOGGER.info("Test mTLS certificate is enabled. " + + "Will use test keystore as truststore so the client will trust the server"); + Map<String, String> optionMap = new HashMap<>(); + optionMap.put("truststore_path", trustStorePath()); + optionMap.put("truststore_password", EMPTY_PASSWORD_STRING); + optionMap.put("truststore_type", trustStoreType()); + optionMap.put("keystore_path", clientKeyStorePath.toString()); + optionMap.put("keystore_password", EMPTY_PASSWORD_STRING); + optionMap.put("keystore_type", "PKCS12"); + return optionMap; + } +} diff --git a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java index 60299ce..e216018 100644 --- a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java +++ b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java @@ -39,6 +39,7 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -72,10 +73,14 @@ import org.apache.cassandra.sidecar.common.dns.DnsResolver; import org.apache.cassandra.sidecar.common.utils.DriverUtils; import org.apache.cassandra.sidecar.common.utils.SidecarVersionProvider; import org.apache.cassandra.sidecar.config.JmxConfiguration; +import org.apache.cassandra.sidecar.config.KeyStoreConfiguration; import org.apache.cassandra.sidecar.config.ServiceConfiguration; import org.apache.cassandra.sidecar.config.SidecarConfiguration; +import org.apache.cassandra.sidecar.config.SslConfiguration; +import org.apache.cassandra.sidecar.config.yaml.KeyStoreConfigurationImpl; import org.apache.cassandra.sidecar.config.yaml.ServiceConfigurationImpl; import org.apache.cassandra.sidecar.config.yaml.SidecarConfigurationImpl; +import org.apache.cassandra.sidecar.config.yaml.SslConfigurationImpl; import org.apache.cassandra.sidecar.exceptions.ThrowableUtils; import org.apache.cassandra.sidecar.server.MainModule; import org.apache.cassandra.sidecar.server.Server; @@ -87,6 +92,7 @@ import org.apache.cassandra.testing.TestUtils; import org.apache.cassandra.testing.TestVersion; import org.apache.cassandra.testing.TestVersionSupplier; +import static org.apache.cassandra.sidecar.testing.MtlsTestHelper.CASSANDRA_INTEGRATION_TEST_ENABLE_MTLS; import static org.assertj.core.api.Assertions.assertThat; /** @@ -131,12 +137,16 @@ public abstract class SharedClusterIntegrationTestBase protected final Logger logger = LoggerFactory.getLogger(SharedClusterIntegrationTestBase.class); private static final int MAX_CLUSTER_PROVISION_RETRIES = 5; + @TempDir + static Path secretsPath; + protected Vertx vertx; protected DnsResolver dnsResolver; protected IClusterExtension<? extends IInstance> cluster; protected Server server; protected Injector injector; protected TestVersion testVersion; + protected MtlsTestHelper mtlsTestHelper; private IsolatedDTestClassLoaderWrapper classLoaderWrapper; static @@ -146,7 +156,7 @@ public abstract class SharedClusterIntegrationTestBase } @BeforeAll - protected void setup() throws InterruptedException + protected void setup() throws Exception { Optional<TestVersion> maybeTestVersion = TestVersionSupplier.testVersions().findFirst(); assertThat(maybeTestVersion).isPresent(); @@ -161,6 +171,7 @@ public abstract class SharedClusterIntegrationTestBase assertThat(cluster).isNotNull(); afterClusterProvisioned(); initializeSchemaForTest(); + mtlsTestHelper = new MtlsTestHelper(secretsPath); startSidecar(cluster); beforeTestStart(); } @@ -306,7 +317,8 @@ public abstract class SharedClusterIntegrationTestBase protected void startSidecar(ICluster<? extends IInstance> cluster) throws InterruptedException { VertxTestContext context = new VertxTestContext(); - injector = Guice.createInjector(Modules.override(new MainModule()).with(new IntegrationTestModule(cluster, classLoaderWrapper))); + AbstractModule testModule = new IntegrationTestModule(cluster, classLoaderWrapper, mtlsTestHelper); + injector = Guice.createInjector(Modules.override(new MainModule()).with(testModule)); dnsResolver = injector.getInstance(DnsResolver.class); vertx = injector.getInstance(Vertx.class); server = injector.getInstance(Server.class); @@ -455,13 +467,18 @@ public abstract class SharedClusterIntegrationTestBase static class IntegrationTestModule extends AbstractModule { + private static final Logger LOGGER = LoggerFactory.getLogger(IntegrationTestModule.class); private final ICluster<? extends IInstance> cluster; private final IsolatedDTestClassLoaderWrapper wrapper; + private final MtlsTestHelper mtlsTestHelper; - IntegrationTestModule(ICluster<? extends IInstance> cluster, IsolatedDTestClassLoaderWrapper wrapper) + IntegrationTestModule(ICluster<? extends IInstance> cluster, + IsolatedDTestClassLoaderWrapper wrapper, + MtlsTestHelper mtlsTestHelper) { this.cluster = cluster; this.wrapper = wrapper; + this.mtlsTestHelper = mtlsTestHelper; } @Provides @@ -500,8 +517,39 @@ public abstract class SharedClusterIntegrationTestBase .host("0.0.0.0") // binds to all interfaces, potential security issue if left running for long .port(0) // let the test find an available port .build(); + + + SslConfiguration sslConfiguration = null; + if (mtlsTestHelper.isEnabled()) + { + LOGGER.info("Enabling test mTLS certificate/keystore."); + + KeyStoreConfiguration truststoreConfiguration = + new KeyStoreConfigurationImpl(mtlsTestHelper.trustStorePath(), + mtlsTestHelper.trustStorePassword(), + mtlsTestHelper.trustStoreType(), + -1); + + KeyStoreConfiguration keyStoreConfiguration = + new KeyStoreConfigurationImpl(mtlsTestHelper.serverKeyStorePath(), + mtlsTestHelper.serverKeyStorePassword(), + mtlsTestHelper.serverKeyStoreType(), + -1); + + sslConfiguration = SslConfigurationImpl.builder() + .enabled(true) + .keystore(keyStoreConfiguration) + .truststore(truststoreConfiguration) + .build(); + } + else + { + LOGGER.info("Not enabling mTLS for testing purposes. Set '{}' to 'true' if you would " + + "like mTLS enabled.", CASSANDRA_INTEGRATION_TEST_ENABLE_MTLS); + } return SidecarConfigurationImpl.builder() .serviceConfiguration(conf) + .sslConfiguration(sslConfiguration) .build(); } diff --git a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/testing/utils/tls/CertificateBuilder.java b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/testing/utils/tls/CertificateBuilder.java new file mode 100644 index 0000000..0bc4671 --- /dev/null +++ b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/testing/utils/tls/CertificateBuilder.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.testing.utils.tls; + +import java.io.IOException; +import java.math.BigInteger; +import java.security.GeneralSecurityException; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.SecureRandom; +import java.security.cert.X509Certificate; +import java.security.spec.AlgorithmParameterSpec; +import java.security.spec.ECGenParameterSpec; +import java.security.spec.RSAKeyGenParameterSpec; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Objects; +import javax.security.auth.x500.X500Principal; + +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.asn1.x509.BasicConstraints; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.asn1.x509.GeneralName; +import org.bouncycastle.asn1.x509.GeneralNames; +import org.bouncycastle.cert.X509CertificateHolder; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; + +/** + * A utility class to generate certificates for tests. + * + * <p>This class is copied from the Apache Cassandra code + */ +public class CertificateBuilder +{ + private static final GeneralName[] EMPTY_SAN = {}; + private static final SecureRandom SECURE_RANDOM = new SecureRandom(); + + private boolean isCertificateAuthority; + private String alias; + private X500Name subject; + private SecureRandom random; + private Date notBefore = Date.from(Instant.now().minus(1, ChronoUnit.DAYS)); + private Date notAfter = Date.from(Instant.now().plus(1, ChronoUnit.DAYS)); + private String algorithm; + private AlgorithmParameterSpec algorithmParameterSpec; + private String signatureAlgorithm; + private BigInteger serial; + private final List<GeneralName> subjectAlternativeNames = new ArrayList<>(); + + public CertificateBuilder() + { + ecp256Algorithm(); + } + + public CertificateBuilder isCertificateAuthority(boolean isCertificateAuthority) + { + this.isCertificateAuthority = isCertificateAuthority; + return this; + } + + public CertificateBuilder subject(String subject) + { + this.subject = new X500Name(Objects.requireNonNull(subject)); + return this; + } + + public CertificateBuilder notBefore(Instant notBefore) + { + return notBefore(Date.from(Objects.requireNonNull(notBefore))); + } + + private CertificateBuilder notBefore(Date notBefore) + { + this.notBefore = Objects.requireNonNull(notBefore); + return this; + } + + public CertificateBuilder notAfter(Instant notAfter) + { + return notAfter(Date.from(Objects.requireNonNull(notAfter))); + } + + private CertificateBuilder notAfter(Date notAfter) + { + this.notAfter = Objects.requireNonNull(notAfter); + return this; + } + + public CertificateBuilder addSanUriName(String uri) + { + subjectAlternativeNames.add(new GeneralName(GeneralName.uniformResourceIdentifier, uri)); + return this; + } + + public CertificateBuilder addSanDnsName(String dnsName) + { + subjectAlternativeNames.add(new GeneralName(GeneralName.dNSName, dnsName)); + return this; + } + + public CertificateBuilder secureRandom(SecureRandom secureRandom) + { + this.random = Objects.requireNonNull(secureRandom); + return this; + } + + public CertificateBuilder alias(String alias) + { + this.alias = Objects.requireNonNull(alias); + return this; + } + + public CertificateBuilder serial(BigInteger serial) + { + this.serial = serial; + return this; + } + + public CertificateBuilder ecp256Algorithm() + { + this.algorithm = "EC"; + this.algorithmParameterSpec = new ECGenParameterSpec("secp256r1"); + this.signatureAlgorithm = "SHA256WITHECDSA"; + return this; + } + + public CertificateBuilder rsa2048Algorithm() + { + this.algorithm = "RSA"; + this.algorithmParameterSpec = new RSAKeyGenParameterSpec(2048, RSAKeyGenParameterSpec.F4); + this.signatureAlgorithm = "SHA256WITHRSA"; + return this; + } + + public CertificateBundle buildSelfSigned() throws Exception + { + KeyPair keyPair = generateKeyPair(); + + JcaX509v3CertificateBuilder builder = createCertBuilder(subject, subject, keyPair); + addExtensions(builder); + + ContentSigner signer = new JcaContentSignerBuilder(signatureAlgorithm).build(keyPair.getPrivate()); + X509CertificateHolder holder = builder.build(signer); + X509Certificate root = new JcaX509CertificateConverter().getCertificate(holder); + return new CertificateBundle(signatureAlgorithm, new X509Certificate[]{root}, root, keyPair, alias); + } + + public CertificateBundle buildIssuedBy(CertificateBundle issuer) throws Exception + { + String issuerSignAlgorithm = issuer.signatureAlgorithm(); + return buildIssuedBy(issuer, issuerSignAlgorithm); + } + + public CertificateBundle buildIssuedBy(CertificateBundle issuer, String issuerSignAlgorithm) throws Exception + { + KeyPair keyPair = generateKeyPair(); + + X500Principal issuerPrincipal = issuer.certificate().getSubjectX500Principal(); + X500Name issuerName = X500Name.getInstance(issuerPrincipal.getEncoded()); + JcaX509v3CertificateBuilder builder = createCertBuilder(issuerName, subject, keyPair); + + addExtensions(builder); + + PrivateKey issuerPrivateKey = issuer.keyPair().getPrivate(); + if (issuerPrivateKey == null) + { + throw new IllegalArgumentException("Cannot sign certificate with issuer that does not have a private key."); + } + ContentSigner signer = new JcaContentSignerBuilder(issuerSignAlgorithm).build(issuerPrivateKey); + X509CertificateHolder holder = builder.build(signer); + X509Certificate cert = new JcaX509CertificateConverter().getCertificate(holder); + X509Certificate[] issuerPath = issuer.certificatePath(); + X509Certificate[] path = new X509Certificate[issuerPath.length + 1]; + path[0] = cert; + System.arraycopy(issuerPath, 0, path, 1, issuerPath.length); + return new CertificateBundle(signatureAlgorithm, path, issuer.rootCertificate(), keyPair, alias); + } + + private SecureRandom secureRandom() + { + return random != null ? random : SECURE_RANDOM; + } + + private KeyPair generateKeyPair() throws GeneralSecurityException + { + KeyPairGenerator keyGen = KeyPairGenerator.getInstance(algorithm); + keyGen.initialize(algorithmParameterSpec, secureRandom()); + return keyGen.generateKeyPair(); + } + + private JcaX509v3CertificateBuilder createCertBuilder(X500Name issuer, X500Name subject, KeyPair keyPair) + { + BigInteger serial = this.serial != null ? this.serial : new BigInteger(159, secureRandom()); + PublicKey pubKey = keyPair.getPublic(); + return new JcaX509v3CertificateBuilder(issuer, serial, notBefore, notAfter, subject, pubKey); + } + + private void addExtensions(JcaX509v3CertificateBuilder builder) throws IOException + { + if (isCertificateAuthority) + { + builder.addExtension(Extension.basicConstraints, true, new BasicConstraints(true)); + } + + boolean criticality = false; + if (!subjectAlternativeNames.isEmpty()) + { + builder.addExtension(Extension.subjectAlternativeName, criticality, + new GeneralNames(subjectAlternativeNames.toArray(EMPTY_SAN))); + } + } +} diff --git a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/testing/utils/tls/CertificateBundle.java b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/testing/utils/tls/CertificateBundle.java new file mode 100644 index 0000000..34a6f02 --- /dev/null +++ b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/testing/utils/tls/CertificateBundle.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.testing.utils.tls; + +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.security.KeyPair; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.cert.X509Certificate; +import java.util.Objects; + +/** + * This class is copied from the Apache Cassandra code + */ +public class CertificateBundle +{ + private final String signatureAlgorithm; + private final X509Certificate[] chain; + private final X509Certificate root; + private final KeyPair keyPair; + private final String alias; + + public CertificateBundle(String signatureAlgorithm, X509Certificate[] chain, + X509Certificate root, KeyPair keyPair, String alias) + { + this.signatureAlgorithm = Objects.requireNonNull(signatureAlgorithm); + this.chain = chain; + this.root = root; + this.keyPair = keyPair; + this.alias = alias != null ? alias : "1"; + } + + public KeyStore toKeyStore(char[] keyEntryPassword) throws KeyStoreException + { + KeyStore keyStore; + try + { + keyStore = KeyStore.getInstance("PKCS12"); + keyStore.load(null, null); + } + catch (Exception e) + { + throw new RuntimeException("Failed to initialize PKCS#12 KeyStore.", e); + } + keyStore.setCertificateEntry("1", root); + if (!isCertificateAuthority()) + { + keyStore.setKeyEntry(alias, keyPair.getPrivate(), keyEntryPassword, chain); + } + return keyStore; + } + + public Path toTempKeyStorePath(Path baseDir, char[] pkcs12Password, char[] keyEntryPassword) throws Exception + { + KeyStore keyStore = toKeyStore(keyEntryPassword); + Path tempFile = Files.createTempFile(baseDir, "ks", ".p12"); + try (OutputStream out = Files.newOutputStream(tempFile, StandardOpenOption.WRITE)) + { + keyStore.store(out, pkcs12Password); + } + return tempFile; + } + + public boolean isCertificateAuthority() + { + return chain[0].getBasicConstraints() != -1; + } + + public X509Certificate certificate() + { + return chain[0]; + } + + public KeyPair keyPair() + { + return keyPair; + } + + public X509Certificate[] certificatePath() + { + return chain.clone(); + } + + public X509Certificate rootCertificate() + { + return root; + } + + public String signatureAlgorithm() + { + return signatureAlgorithm; + } +} diff --git a/cassandra-analytics-integration-tests/build.gradle b/cassandra-analytics-integration-tests/build.gradle index 3230831..a6af9f9 100644 --- a/cassandra-analytics-integration-tests/build.gradle +++ b/cassandra-analytics-integration-tests/build.gradle @@ -44,6 +44,9 @@ println("Using ${integrationMaxHeapSize} maxHeapSize") def integrationMaxParallelForks = (System.getenv("INTEGRATION_MAX_PARALLEL_FORKS") ?: "4") as int println("Using ${integrationMaxParallelForks} maxParallelForks") +def integrationEnableMtls = (System.getenv("INTEGRATION_MTLS_ENABLED") ?: "true") as boolean +println("Using mTLS for tests? ${integrationEnableMtls}") + configurations { // remove netty-all dependency coming from spark all*.exclude(group: 'io.netty', module: 'netty-all') @@ -77,6 +80,7 @@ test { System.getProperty("cassandra.sidecar.versions_to_test", "4.1") systemProperty "SKIP_STARTUP_VALIDATIONS", "true" systemProperty "logback.configurationFile", "src/test/resources/logback-test.xml" + systemProperty "cassandra.integration.sidecar.test.enable_mtls", integrationEnableMtls minHeapSize = '1g' maxHeapSize = integrationMaxHeapSize maxParallelForks = integrationMaxParallelForks diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BlockedInstancesTest.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BlockedInstancesTest.java index 83e1ad7..82b5d9a 100644 --- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BlockedInstancesTest.java +++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BlockedInstancesTest.java @@ -20,6 +20,7 @@ package org.apache.cassandra.analytics; import java.lang.reflect.Method; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -66,9 +67,11 @@ public class BlockedInstancesTest extends ResiliencyTestBase { TestConsistencyLevel cl = TestConsistencyLevel.of(QUORUM, QUORUM); QualifiedName table = new QualifiedName(TEST_KEYSPACE, tableName(testInfo)); - bulkWriterDataFrameWriter(df, table).option(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name()) - .option(WriterOptions.BLOCKED_CASSANDRA_INSTANCES.name(), "127.0.0.2") - .save(); + Map<String, String> additionalOptions = new HashMap<>(); + additionalOptions.put(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name()); + additionalOptions.put(WriterOptions.BLOCKED_CASSANDRA_INSTANCES.name(), "127.0.0.2"); + + bulkWriterDataFrameWriter(df, table, additionalOptions).save(); expectedInstanceData.entrySet() .stream() .filter(e -> e.getKey().broadcastAddress().getAddress().getHostAddress().equals("127.0.0.2")) @@ -82,10 +85,12 @@ public class BlockedInstancesTest extends ResiliencyTestBase { TestConsistencyLevel cl = TestConsistencyLevel.of(ONE, ALL); QualifiedName table = new QualifiedName(TEST_KEYSPACE, tableName(testInfo)); - Throwable thrown = catchThrowable(() -> - bulkWriterDataFrameWriter(df, table).option(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name()) - .option(WriterOptions.BLOCKED_CASSANDRA_INSTANCES.name(), "127.0.0.2") - .save()); + Throwable thrown = catchThrowable(() -> { + Map<String, String> additionalOptions = new HashMap<>(); + additionalOptions.put(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name()); + additionalOptions.put(WriterOptions.BLOCKED_CASSANDRA_INSTANCES.name(), "127.0.0.2"); + bulkWriterDataFrameWriter(df, table, additionalOptions).save(); + }); validateFailedJob(table, cl, thrown); } @@ -94,10 +99,12 @@ public class BlockedInstancesTest extends ResiliencyTestBase { TestConsistencyLevel cl = TestConsistencyLevel.of(QUORUM, QUORUM); QualifiedName table = new QualifiedName(TEST_KEYSPACE, tableName(testInfo)); - Throwable thrown = catchThrowable(() -> - bulkWriterDataFrameWriter(df, table).option(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name()) - .option(WriterOptions.BLOCKED_CASSANDRA_INSTANCES.name(), "127.0.0.2,127.0.0.3") - .save()); + Throwable thrown = catchThrowable(() -> { + Map<String, String> additionalOptions = new HashMap<>(); + additionalOptions.put(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name()); + additionalOptions.put(WriterOptions.BLOCKED_CASSANDRA_INSTANCES.name(), "127.0.0.2,127.0.0.3"); + bulkWriterDataFrameWriter(df, table, additionalOptions).save(); + }); validateFailedJob(table, cl, thrown); } diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/DataGenerationUtils.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/DataGenerationUtils.java index 3c1897d..8ed742a 100644 --- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/DataGenerationUtils.java +++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/DataGenerationUtils.java @@ -66,7 +66,7 @@ public final class DataGenerationUtils */ public static Dataset<Row> generateCourseData(SparkSession spark, int rowCount) { - return generateCourseData(spark, rowCount, false); + return generateCourseData(spark, rowCount, false, null, null); } /** @@ -84,6 +84,28 @@ public final class DataGenerationUtils * @return a {@link Dataset} with generated data */ public static Dataset<Row> generateCourseData(SparkSession spark, int rowCount, boolean udfData) + { + return generateCourseData(spark, rowCount, udfData, null, null); + } + + /** + * Generates course data with schema + * + * <pre> + * id integer, + * course string, + * marks integer + * </pre> + * + * @param spark the spark session to use + * @param rowCount the number of records to generate + * @param udfData if a field representing a User Defined Type should be added + * @param ttl (optional) a TTL value for the data frame + * @param timestamp (optional) a timestamp value for the data frame + * @return a {@link Dataset} with generated data + */ + public static Dataset<Row> generateCourseData(SparkSession spark, int rowCount, boolean udfData, + Integer ttl, Long timestamp) { SQLContext sql = spark.sqlContext(); StructType schema = new StructType() @@ -98,15 +120,33 @@ public final class DataGenerationUtils schema = schema.add("User_Defined_Type", udfType); } + if (ttl != null) + { + schema = schema.add("ttl", IntegerType, false); + } + + if (timestamp != null) + { + schema = schema.add("timestamp", LongType, false); + } + List<Row> rows = IntStream.range(0, rowCount) .mapToObj(recordNum -> { String course = "course" + recordNum; - if (!udfData) + List<Object> values = new ArrayList<>(Arrays.asList(recordNum, course, recordNum)); + if (udfData) + { + values.add(RowFactory.create(recordNum, recordNum)); + } + if (ttl != null) + { + values.add(ttl); + } + if (timestamp != null) { - return RowFactory.create(recordNum, course, recordNum); + values.add(timestamp); } - return RowFactory.create(recordNum, course, recordNum, - RowFactory.create(recordNum, recordNum)); + return RowFactory.create(values.toArray()); }).collect(Collectors.toList()); return sql.createDataFrame(rows, schema); } diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/IntegrationTestJob.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/IntegrationTestJob.java deleted file mode 100644 index acf4200..0000000 --- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/IntegrationTestJob.java +++ /dev/null @@ -1,379 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.cassandra.analytics; - -import java.io.Serializable; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import java.util.stream.LongStream; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.cassandra.spark.KryoRegister; -import org.apache.cassandra.spark.bulkwriter.BulkSparkConf; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.sql.DataFrameWriter; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.types.StructType; -import org.jetbrains.annotations.NotNull; - -/** - * Spark job for use in integration tests. It contains a framework for generating data, - * writing using the bulk writer, and then reading that data back using the reader. - * It has extension points for the actual row and schema generation. - */ -public final class IntegrationTestJob implements Serializable -{ - private static final Logger LOGGER = LoggerFactory.getLogger(IntegrationTestJob.class); - @NotNull - private final RowGenerator rowGenerator; - // Transient because we only need it on the driver. - @NotNull - private final transient Function<Dataset<Row>, Dataset<Row>> postWriteDatasetModifier; - private final int rowCount; - private final int sidecarPort; - private final StructType writeSchema; - @NotNull - private final Map<String, String> extraWriterOptions; - private final boolean shouldWrite; - private final boolean shouldRead; - private final String table; - private String keyspace; - - //CHECKSTYLE IGNORE: This is only called from the Builder - private IntegrationTestJob(@NotNull RowGenerator rowGenerator, - @NotNull StructType writeSchema, - @NotNull Function<Dataset<Row>, Dataset<Row>> postWriteDatasetModifier, - @NotNull int rowCount, - @NotNull int sidecarPort, - @NotNull Map<String, String> extraWriterOptions, - boolean shouldWrite, - boolean shouldRead, String keyspace, String table) - { - this.rowGenerator = rowGenerator; - this.postWriteDatasetModifier = postWriteDatasetModifier; - this.rowCount = rowCount; - this.sidecarPort = sidecarPort; - this.writeSchema = writeSchema; - this.extraWriterOptions = extraWriterOptions; - this.shouldWrite = shouldWrite; - this.shouldRead = shouldRead; - this.keyspace = keyspace; - this.table = table; - } - - public static Builder builder(RowGenerator rowGenerator, StructType writeSchema) - { - return new Builder(rowGenerator, writeSchema); - } - - private static void checkSmallDataFrameEquality(Dataset<Row> expected, Dataset<Row> actual) - { - if (actual == null) - { - throw new NullPointerException("actual dataframe is null"); - } - if (expected == null) - { - throw new NullPointerException("expected dataframe is null"); - } - // Simulate `actual` having fewer rows, but all match rows in `expected`. - // The previous implementation would consider these equal - // actual = actual.limit(1000); - if (!actual.exceptAll(expected).isEmpty() || !expected.exceptAll(actual).isEmpty()) - { - throw new IllegalStateException("The content of the dataframes differs"); - } - } - - private Dataset<Row> write(SQLContext sql, SparkContext sc) - { - JavaSparkContext javaSparkContext = JavaSparkContext.fromSparkContext(sc); - JavaRDD<Row> rows = genDataset(javaSparkContext, sc.defaultParallelism()); - Dataset<Row> df = sql.createDataFrame(rows, writeSchema); - - DataFrameWriter<Row> dfWriter = df.write() - .format("org.apache.cassandra.spark.sparksql.CassandraDataSink") - .option("sidecar_instances", "localhost,localhost2,localhost3") - .option("sidecar_port", sidecarPort) - .option("keyspace", keyspace) - .option("table", table) - .option("local_dc", "datacenter1") - .option("bulk_writer_cl", "LOCAL_QUORUM") - .option("number_splits", "-1") - .options(extraWriterOptions) - .mode("append"); - dfWriter.save(); - return df; - } - - private Dataset<Row> read(SparkConf sparkConf, SQLContext sql, SparkContext sc) - { - int expectedRowCount = rowCount; - int coresPerExecutor = sparkConf.getInt("spark.executor.cores", 1); - int numExecutors = sparkConf.getInt("spark.dynamicAllocation.maxExecutors", sparkConf.getInt("spark.executor.instances", 1)); - int numCores = coresPerExecutor * numExecutors; - - Dataset<Row> df = sql.read().format("org.apache.cassandra.spark.sparksql.CassandraDataSource") - .option("sidecar_instances", "localhost,localhost2,localhost3") - .option("sidecar_port", sidecarPort) - .option("keyspace", "spark_test") - .option("table", "test") - .option("DC", "datacenter1") - .option("snapshotName", UUID.randomUUID().toString()) - .option("createSnapshot", "true") - // Shutdown hooks are called after the job ends, and in the case of integration tests - // the sidecar is already shut down before this. Since the cluster will be torn - // down anyway, the integration job skips clearing snapshots. - .option("clearSnapshot", "false") - .option("defaultParallelism", sc.defaultParallelism()) - .option("numCores", numCores) - .option("sizing", "default") - .load(); - - long count = df.count(); - LOGGER.info("Found {} records", count); - - if (count != expectedRowCount) - { - throw new RuntimeException(String.format("Expected %d records but found %d records", - expectedRowCount, - count)); - } - return df; - } - - private JavaRDD<Row> genDataset(JavaSparkContext sc, int parallelism) - { - long recordsPerPartition = rowCount / parallelism; - long remainder = rowCount - (recordsPerPartition * parallelism); - List<Integer> seq = IntStream.range(0, parallelism).boxed().collect(Collectors.toList()); - JavaRDD<Row> dataset = sc.parallelize(seq, parallelism).mapPartitionsWithIndex( - (Function2<Integer, Iterator<Integer>, Iterator<Row>>) (index, integerIterator) -> { - long firstRecordNumber = index * recordsPerPartition; - long recordsToGenerate = index.equals(parallelism) ? remainder : recordsPerPartition; - java.util.Iterator<Row> rows = LongStream.range(0, recordsToGenerate).mapToObj(offset -> { - long recordNumber = firstRecordNumber + offset; - return rowGenerator.rowFor(recordNumber); - }).iterator(); - return rows; - }, false); - return dataset; - } - - public void run() - { - LOGGER.info("Starting Spark job with args={}", this); - - SparkConf sparkConf = new SparkConf().setAppName("Sample Spark Cassandra Bulk Reader Job") - .set("spark.master", "local[8]"); - - // Add SBW-specific settings - // TODO: Simplify setting up spark conf - BulkSparkConf.setupSparkConf(sparkConf, true); - KryoRegister.setup(sparkConf); - - SparkSession spark = SparkSession - .builder() - .config(sparkConf) - .getOrCreate(); - SparkContext sc = spark.sparkContext(); - SQLContext sql = spark.sqlContext(); - LOGGER.info("Spark Conf: " + sparkConf.toDebugString()); - - try - { - Dataset<Row> written = null; - Dataset<Row> read = null; - if (shouldWrite) - { - written = write(sql, sc); - written = postWriteDatasetModifier.apply(written); - } - - if (shouldRead) - { - read = read(sparkConf, sql, sc); - } - if (read != null && written != null) - { - checkSmallDataFrameEquality(written, read); - } - LOGGER.info("Finished Spark job, shutting down..."); - sc.stop(); - } - catch (Throwable throwable) - { - LOGGER.error("Unexpected exception executing Spark job", throwable); - try - { - sc.stop(); - } - catch (Throwable ignored) - { - } - throw throwable; // rethrow so the exception bubbles up to test usages. - } - } - - @Override - public String toString() - { - return "IntegrationTestJob: { \n " + - " rowCount:%d,\n" + - " parallelism:%d,\n" + - " ttl:%d,\n" + - " timestamp:%d,\n" + - " sidecarPort:%d\n" + - "}"; - } - - @FunctionalInterface - public interface RowGenerator extends Serializable - { - Row rowFor(long recordNumber); - } - - public static class Builder - { - private final RowGenerator rowGenerator; - private final StructType writeSchema; - private int rowCount = 10_000; - private int sidecarPort = 9043; - private Map<String, String> extraWriterOptions; - private Function<Dataset<Row>, Dataset<Row>> postWriteDatasetModifier = Function.identity(); - private boolean shouldWrite = true; - private boolean shouldRead = true; - private String keyspace = "spark_test"; - private String table = "test"; - - Builder(@NotNull RowGenerator rowGenerator, StructType writeSchema) - { - this.rowGenerator = rowGenerator; - this.writeSchema = writeSchema; - } - - public Builder withKeyspace(String keyspace) - { - return update(builder -> builder.keyspace = keyspace); - } - - public Builder withTable(String table) - { - return update(builder -> builder.table = table); - } - - public Builder withRowCount(int rowCount) - { - return update(builder -> builder.rowCount = rowCount); - } - - public Builder withSidecarPort(int sidecarPort) - { - return update(builder -> builder.sidecarPort = sidecarPort); - } - - private Builder update(Consumer<Builder> update) - { - update.accept(this); - return this; - } - - public Builder withExtraWriterOptions(Map<String, String> writerOptions) - { - return update(builder -> builder.extraWriterOptions = writerOptions); - } - - public Builder withPostWriteDatasetModifier(Function<Dataset<Row>, Dataset<Row>> dataSetModifier) - { - return update(builder -> builder.postWriteDatasetModifier = dataSetModifier); - } - - - public Builder shouldWrite(boolean shouldWrite) - { - return update(builder -> builder.shouldWrite = shouldWrite); - } - - public Builder shouldRead(boolean shouldRead) - { - return update(builder -> builder.shouldRead = shouldRead); - } - - - public IntegrationTestJob build() - { - return new IntegrationTestJob(this.rowGenerator, - this.writeSchema, - this.postWriteDatasetModifier, - this.rowCount, - this.sidecarPort, - this.extraWriterOptions, - this.shouldWrite, - this.shouldRead, - this.keyspace, - this.table); - } - - public void run() - { - build().run(); - } - } - - /** - * An example of a postWriteDatasetModifier. - * This function will remove `ttl` and `timestamp` columns from the dataframe after write, as - * the read part of the integration test job doesn't read ttl and timestamp columns, we need to remove them - * from the Dataset after it's saved so the final comparison works. - * @param addedTTLColumn if the "ttl" column was added to the dataframe. - * @param addedTimestampColumn if the "timestamp" column was added to the dataframe. - * @return the modified dataset - */ - @SuppressWarnings("unused") - public static Function<Dataset<Row>, Dataset<Row>> ttlRemovalModifier(boolean addedTTLColumn, boolean addedTimestampColumn) - { - return (Dataset<Row> df) -> { - if (addedTTLColumn) - { - df = df.drop("ttl"); - } - if (addedTimestampColumn) - { - df = df.drop("timestamp"); - } - return df; - }; - } -} diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java index d92a737..daf7469 100644 --- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java +++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java @@ -19,8 +19,10 @@ package org.apache.cassandra.analytics; +import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; @@ -68,7 +70,7 @@ public abstract class SharedClusterSparkIntegrationTestBase extends SharedCluste protected void beforeTestStart() { super.beforeTestStart(); - sparkTestUtils.initialize(cluster.delegate(), dnsResolver, server.actualPort()); + sparkTestUtils.initialize(cluster.delegate(), dnsResolver, server.actualPort(), mtlsTestHelper); } @Override @@ -101,7 +103,22 @@ public abstract class SharedClusterSparkIntegrationTestBase extends SharedCluste */ protected DataFrameWriter<Row> bulkWriterDataFrameWriter(Dataset<Row> df, QualifiedName tableName) { - return sparkTestUtils.defaultBulkWriterDataFrameWriter(df, tableName); + return sparkTestUtils.defaultBulkWriterDataFrameWriter(df, tableName, Collections.emptyMap()); + } + + /** + * A preconfigured {@link DataFrameWriter} with pre-populated required options that can be overridden + * with additional options for every specific test. + * + * @param df the source dataframe to write + * @param tableName the qualified name for the Cassandra table + * @param additionalOptions additional options for the data frame + * @return a {@link DataFrameWriter} for Cassandra bulk writes + */ + protected DataFrameWriter<Row> bulkWriterDataFrameWriter(Dataset<Row> df, QualifiedName tableName, + Map<String, String> additionalOptions) + { + return sparkTestUtils.defaultBulkWriterDataFrameWriter(df, tableName, additionalOptions); } protected SparkConf getOrCreateSparkConf() @@ -135,6 +152,25 @@ public abstract class SharedClusterSparkIntegrationTestBase extends SharedCluste return bridge; } + public void checkSmallDataFrameEquality(Dataset<Row> expected, Dataset<Row> actual) + { + if (actual == null) + { + throw new NullPointerException("actual dataframe is null"); + } + if (expected == null) + { + throw new NullPointerException("expected dataframe is null"); + } + // Simulate `actual` having fewer rows, but all match rows in `expected`. + // The previous implementation would consider these equal + // actual = actual.limit(1000); + if (!actual.exceptAll(expected).isEmpty() || !expected.exceptAll(actual).isEmpty()) + { + throw new IllegalStateException("The content of the dataframes differs"); + } + } + public void validateWritesWithDriverResultSet(List<Row> sourceData, ResultSet queriedData, Function<com.datastax.driver.core.Row, String> rowFormatter) { diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkBulkWriterSimpleTest.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkBulkWriterSimpleTest.java index ebb6860..3c7f156 100644 --- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkBulkWriterSimpleTest.java +++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkBulkWriterSimpleTest.java @@ -19,9 +19,6 @@ package org.apache.cassandra.analytics; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.function.Function; @@ -36,61 +33,57 @@ import org.apache.cassandra.spark.bulkwriter.WriterOptions; import org.apache.cassandra.testing.ClusterBuilderConfiguration; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructType; -import org.jetbrains.annotations.NotNull; +import org.apache.spark.sql.SparkSession; +import static org.apache.cassandra.testing.TestUtils.CREATE_TEST_TABLE_STATEMENT; import static org.apache.cassandra.testing.TestUtils.DC1_RF3; +import static org.apache.cassandra.testing.TestUtils.ROW_COUNT; import static org.apache.cassandra.testing.TestUtils.TEST_KEYSPACE; -import static org.apache.spark.sql.types.DataTypes.BinaryType; -import static org.apache.spark.sql.types.DataTypes.IntegerType; -import static org.apache.spark.sql.types.DataTypes.LongType; class SparkBulkWriterSimpleTest extends SharedClusterSparkIntegrationTestBase { - static final String CREATE_TABLE_STATEMENT = "CREATE TABLE %s (id BIGINT PRIMARY KEY, course BLOB, marks BIGINT)"; + static final QualifiedName QUALIFIED_NAME = new QualifiedName(TEST_KEYSPACE, "test"); @ParameterizedTest @MethodSource("options") void runSampleJob(Integer ttl, Long timestamp) { Map<String, String> writerOptions = new HashMap<>(); - - boolean addTTLColumn = ttl != null; - boolean addTimestampColumn = timestamp != null; - if (addTTLColumn) + if (ttl != null) { writerOptions.put(WriterOptions.TTL.name(), "ttl"); } - - if (addTimestampColumn) + if (timestamp != null) { writerOptions.put(WriterOptions.TIMESTAMP.name(), "timestamp"); } - IntegrationTestJob.builder((recordNum) -> generateCourse(recordNum, ttl, timestamp), - getWriteSchema(addTTLColumn, addTimestampColumn)) - .withSidecarPort(server.actualPort()) - .withExtraWriterOptions(writerOptions) - .withPostWriteDatasetModifier(writeToReadDfFunc(addTTLColumn, addTimestampColumn)) - .run(); - } + SparkSession spark = getOrCreateSparkSession(); - static Stream<Arguments> options() - { - return Stream.of( - Arguments.of(null, null), - Arguments.of(1000, null), - Arguments.of(null, 1432815430948567L), - Arguments.of(1000, 1432815430948567L) - ); + // Generate some data + Dataset<Row> dfWrite = DataGenerationUtils.generateCourseData(spark, ROW_COUNT, false, ttl, timestamp); + + // Write the data using Bulk Writer + bulkWriterDataFrameWriter(dfWrite, QUALIFIED_NAME, writerOptions).save(); + + // Validate using CQL + sparkTestUtils.validateWrites(dfWrite.collectAsList(), queryAllData(QUALIFIED_NAME)); + + // Remove columns from write DF to perform validations + Dataset<Row> written = writeToReadDfFunc(ttl != null, timestamp != null).apply(dfWrite); + + // Read data back using Bulk Reader + Dataset<Row> read = bulkReaderDataFrame(QUALIFIED_NAME).load(); + + // Validate that written and read dataframes are the same + checkSmallDataFrameEquality(written, read); } @Override protected void initializeSchemaForTest() { createTestKeyspace(TEST_KEYSPACE, DC1_RF3); - createTestTable(new QualifiedName(TEST_KEYSPACE, "test"), CREATE_TABLE_STATEMENT); + createTestTable(QUALIFIED_NAME, CREATE_TEST_TABLE_STATEMENT); } @Override @@ -100,9 +93,19 @@ class SparkBulkWriterSimpleTest extends SharedClusterSparkIntegrationTestBase .nodesPerDc(3); } + static Stream<Arguments> options() + { + return Stream.of( + Arguments.of(null, null), + Arguments.of(1000, null), + Arguments.of(null, 1432815430948567L), + Arguments.of(1000, 1432815430948567L) + ); + } + // Because the read part of the integration test job doesn't read ttl and timestamp columns, we need to remove them // from the Dataset after it's saved. - private Function<Dataset<Row>, Dataset<Row>> writeToReadDfFunc(boolean addedTTLColumn, boolean addedTimestampColumn) + static Function<Dataset<Row>, Dataset<Row>> writeToReadDfFunc(boolean addedTTLColumn, boolean addedTimestampColumn) { return (Dataset<Row> df) -> { if (addedTTLColumn) @@ -116,53 +119,4 @@ class SparkBulkWriterSimpleTest extends SharedClusterSparkIntegrationTestBase return df; }; } - - @SuppressWarnings("SameParameterValue") - private static StructType getWriteSchema(boolean addTTLColumn, boolean addTimestampColumn) - { - StructType schema = new StructType() - .add("id", LongType, false) - .add("course", BinaryType, false) - .add("marks", LongType, false); - if (addTTLColumn) - { - schema = schema.add("ttl", IntegerType, false); - } - if (addTimestampColumn) - { - schema = schema.add("timestamp", LongType, false); - } - return schema; - } - - @NotNull - @SuppressWarnings("SameParameterValue") - private static Row generateCourse(long recordNumber, Integer ttl, Long timestamp) - { - String courseNameString = String.valueOf(recordNumber); - int courseNameStringLen = courseNameString.length(); - int courseNameMultiplier = 1000 / courseNameStringLen; - byte[] courseName = dupStringAsBytes(courseNameString, courseNameMultiplier); - ArrayList<Object> values = new ArrayList<>(Arrays.asList(recordNumber, courseName, recordNumber)); - if (ttl != null) - { - values.add(ttl); - } - if (timestamp != null) - { - values.add(timestamp); - } - return RowFactory.create(values.toArray()); - } - - private static byte[] dupStringAsBytes(String string, Integer times) - { - byte[] stringBytes = string.getBytes(); - ByteBuffer buffer = ByteBuffer.allocate(stringBytes.length * times); - for (int time = 0; time < times; time++) - { - buffer.put(stringBytes); - } - return buffer.array(); - } } diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkTestUtils.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkTestUtils.java index 46b70a9..a4a4734 100644 --- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkTestUtils.java +++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkTestUtils.java @@ -22,6 +22,7 @@ package org.apache.cassandra.analytics; import java.net.UnknownHostException; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.UUID; @@ -32,6 +33,7 @@ import org.apache.cassandra.distributed.api.ICluster; import org.apache.cassandra.distributed.api.IInstance; import org.apache.cassandra.distributed.shared.JMXUtil; import org.apache.cassandra.sidecar.common.dns.DnsResolver; +import org.apache.cassandra.sidecar.testing.MtlsTestHelper; import org.apache.cassandra.sidecar.testing.QualifiedName; import org.apache.cassandra.spark.KryoRegister; import org.apache.cassandra.spark.bulkwriter.BulkSparkConf; @@ -52,8 +54,9 @@ import static org.assertj.core.api.Assertions.assertThat; public class SparkTestUtils { protected ICluster<? extends IInstance> cluster; - private DnsResolver dnsResolver; - private int sidecarPort; + protected DnsResolver dnsResolver; + protected int sidecarPort; + protected MtlsTestHelper mtlsTestHelper; /** * Runs any initialization code required for the tests @@ -62,10 +65,12 @@ public class SparkTestUtils * @param dnsResolver the DNS resolver used to lookup replicas * @param sidecarPort the port where Sidecar is running */ - public void initialize(ICluster<? extends IInstance> cluster, DnsResolver dnsResolver, int sidecarPort) + public void initialize(ICluster<? extends IInstance> cluster, DnsResolver dnsResolver, int sidecarPort, + MtlsTestHelper mtlsTestHelper) { this.cluster = Objects.requireNonNull(cluster, "cluster is required"); this.dnsResolver = Objects.requireNonNull(dnsResolver, "dnsResolver is required"); + this.mtlsTestHelper = Objects.requireNonNull(mtlsTestHelper, "mtlsTestHelper is required"); this.sidecarPort = sidecarPort; } @@ -110,6 +115,7 @@ public class SparkTestUtils .option("defaultParallelism", sc.defaultParallelism()) .option("numCores", numCores) .option("sizing", "default") + .options(mtlsTestHelper.mtlOptionMap()) .option("sidecar_port", sidecarPort); } @@ -117,12 +123,13 @@ public class SparkTestUtils * Returns a {@link DataFrameWriter<Row>} with default options for performing a bulk write test, including * required parameters. * - * @param df the source data frame - * @param tableName the qualified name of the table + * @param df the source data frame + * @param tableName the qualified name of the table + * @param additionalOptions additional options for the data frame * @return a {@link DataFrameWriter<Row>} with default options for performing a bulk write test */ - public DataFrameWriter<Row> defaultBulkWriterDataFrameWriter(Dataset<Row> df, - QualifiedName tableName) + public DataFrameWriter<Row> defaultBulkWriterDataFrameWriter(Dataset<Row> df, QualifiedName tableName, + Map<String, String> additionalOptions) { return df.write() .format("org.apache.cassandra.spark.sparksql.CassandraDataSink") @@ -133,6 +140,8 @@ public class SparkTestUtils .option("bulk_writer_cl", "LOCAL_QUORUM") .option("number_splits", "-1") .option("sidecar_port", sidecarPort) + .options(additionalOptions) + .options(mtlsTestHelper.mtlOptionMap()) .mode("append"); } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org For additional commands, e-mail: commits-h...@cassandra.apache.org