This is an automated email from the ASF dual-hosted git repository.

frankgh pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-analytics.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 6901018  CASSANDRA-19526: Optionally enable TLS in the server and client for Analytics testing
6901018 is described below

commit 690101840d4d8f9c656bb0ca114f6619af80e1cf
Author: Francisco Guerrero <fran...@apache.org>
AuthorDate: Mon Apr 8 14:33:50 2024 -0700

    CASSANDRA-19526: Optionally enable TLS in the server and client for Analytics testing

    All integration tests today run without TLS, which is generally fine because they run locally.
    However, it is helpful to be able to start up the sidecar with TLS enabled in the integration
    test framework so that third-party tests can connect over secure connections for testing purposes.
    
    Co-authored-by: Doug Rohrer <droh...@apple.com>
    Co-authored-by: Francisco Guerrero <fran...@apache.org>
    
    Patch by Doug Rohrer, Francisco Guerrero; Reviewed by Yifan Cai for CASSANDRA-19526
---
 .../spark/common/stats/JobStatsPublisher.java      |   2 +
 .../build.gradle                                   |   6 +-
 .../distributed/impl/CassandraCluster.java         |   1 +
 .../cassandra/sidecar/testing/MtlsTestHelper.java  | 146 ++++++++
 .../testing/SharedClusterIntegrationTestBase.java  |  54 ++-
 .../testing/utils/tls/CertificateBuilder.java      | 236 +++++++++++++
 .../testing/utils/tls/CertificateBundle.java       | 112 ++++++
 cassandra-analytics-integration-tests/build.gradle |   4 +
 .../cassandra/analytics/BlockedInstancesTest.java  |  29 +-
 .../cassandra/analytics/DataGenerationUtils.java   |  50 ++-
 .../cassandra/analytics/IntegrationTestJob.java    | 379 ---------------------
 .../SharedClusterSparkIntegrationTestBase.java     |  40 ++-
 .../analytics/SparkBulkWriterSimpleTest.java       | 118 ++-----
 .../apache/cassandra/analytics/SparkTestUtils.java |  23 +-
 14 files changed, 710 insertions(+), 490 deletions(-)
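
For context (not part of the commit), a minimal sketch of how a third-party integration test might opt in to the mTLS support added below. It relies only on the MtlsTestHelper API introduced by this patch; the class name, temp-directory handling, and printing are illustrative assumptions.

    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.util.Map;

    import org.apache.cassandra.sidecar.testing.MtlsTestHelper;

    public class MtlsOptInSketch
    {
        public static void main(String[] args) throws Exception
        {
            // The integration-tests Gradle build maps the INTEGRATION_MTLS_ENABLED
            // environment variable onto this system property; tests can also set it directly.
            System.setProperty(MtlsTestHelper.CASSANDRA_INTEGRATION_TEST_ENABLE_MTLS, "true");

            Path secretsDir = Files.createTempDirectory("test-secrets");
            MtlsTestHelper helper = new MtlsTestHelper(secretsDir);

            if (helper.isEnabled())
            {
                // Client-side TLS options (truststore/keystore paths, passwords, types)
                // that a test can merge into its bulk writer/reader options.
                Map<String, String> clientTlsOptions = helper.mtlOptionMap();
                System.out.println(clientTlsOptions);
            }
        }
    }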

diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/stats/JobStatsPublisher.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/stats/JobStatsPublisher.java
index 28643e8..9027ce4 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/stats/JobStatsPublisher.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/stats/JobStatsPublisher.java
@@ -30,6 +30,8 @@ public interface JobStatsPublisher
 {
     /**
      * Publish the job attributes to be persisted and summarized
+     *
+     * @param stats the stats to publish
      */
     void publish(Map<String, String> stats);
 }
diff --git a/cassandra-analytics-integration-framework/build.gradle b/cassandra-analytics-integration-framework/build.gradle
index aeba617..63d66c7 100644
--- a/cassandra-analytics-integration-framework/build.gradle
+++ b/cassandra-analytics-integration-framework/build.gradle
@@ -75,7 +75,11 @@ dependencies {
         exclude group: 'junit', module: 'junit'
     }
     implementation("io.vertx:vertx-web-client:${project.vertxVersion}")
-    implementation group: 'com.fasterxml.jackson.core', name: 'jackson-annotations', version: '2.14.3'
+    implementation(group: 'com.fasterxml.jackson.core', name: 'jackson-annotations', version: '2.14.3')
+
+    // Bouncycastle dependencies for test certificate provisioning
+    implementation(group: 'org.bouncycastle', name: 'bcprov-jdk18on', version: '1.78')
+    implementation(group: 'org.bouncycastle', name: 'bcpkix-jdk18on', version: '1.78')
 
     testImplementation(platform("org.junit:junit-bom:${project.junitVersion}"))
     testImplementation('org.junit.jupiter:junit-jupiter')
diff --git a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/distributed/impl/CassandraCluster.java b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/distributed/impl/CassandraCluster.java
index 4d20c62..f5a5abd 100644
--- a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/distributed/impl/CassandraCluster.java
+++ b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/distributed/impl/CassandraCluster.java
@@ -68,6 +68,7 @@ public class CassandraCluster<I extends IInstance> implements IClusterExtension<
         // java.lang.IllegalStateException: Can't load <CLASS>. Instance class loader is already closed.
         return className.equals("org.apache.cassandra.utils.concurrent.Ref$OnLeak")
                || className.startsWith("org.apache.cassandra.metrics.RestorableMeter")
+               || className.equals("org.apache.logging.slf4j.EventDataConverter")
                || (className.startsWith("org.apache.cassandra.analytics.") && className.contains("BBHelper"));
     };
 
diff --git a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/MtlsTestHelper.java b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/MtlsTestHelper.java
new file mode 100644
index 0000000..55f200f
--- /dev/null
+++ b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/MtlsTestHelper.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.sidecar.testing;
+
+import java.nio.file.Path;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.cassandra.testing.utils.tls.CertificateBuilder;
+import org.apache.cassandra.testing.utils.tls.CertificateBundle;
+
+/**
+ * A class that encapsulates testing with Mutual TLS.
+ */
+public class MtlsTestHelper
+{
+    private static final Logger LOGGER = LoggerFactory.getLogger(MtlsTestHelper.class);
+    public static final char[] EMPTY_PASSWORD = new char[0];
+    public static final String EMPTY_PASSWORD_STRING = "";
+    /**
+     * A system property that can enable / disable testing with Mutual TLS
+     */
+    public static final String CASSANDRA_INTEGRATION_TEST_ENABLE_MTLS = "cassandra.integration.sidecar.test.enable_mtls";
+    private final boolean enableMtlsForTesting;
+    CertificateBundle certificateAuthority;
+    Path truststorePath;
+    Path serverKeyStorePath;
+    Path clientKeyStorePath;
+
+    public MtlsTestHelper(Path secretsPath) throws Exception
+    {
+        this(secretsPath, System.getProperty(CASSANDRA_INTEGRATION_TEST_ENABLE_MTLS, "false").equals("true"));
+    }
+
+    public MtlsTestHelper(Path secretsPath, boolean enableMtlsForTesting) throws Exception
+    {
+        this.enableMtlsForTesting = enableMtlsForTesting;
+        maybeInitializeSecrets(Objects.requireNonNull(secretsPath, "secretsPath cannot be null"));
+    }
+
+    void maybeInitializeSecrets(Path secretsPath) throws Exception
+    {
+        if (!enableMtlsForTesting)
+        {
+            return;
+        }
+
+        certificateAuthority =
+        new CertificateBuilder().subject("CN=Apache Cassandra Root CA, OU=Certification Authority, O=Unknown, C=Unknown")
+                                .alias("fakerootca")
+                                .isCertificateAuthority(true)
+                                .buildSelfSigned();
+        truststorePath = certificateAuthority.toTempKeyStorePath(secretsPath, EMPTY_PASSWORD, EMPTY_PASSWORD);
+
+        CertificateBuilder serverKeyStoreBuilder =
+        new CertificateBuilder().subject("CN=Apache Cassandra, OU=mtls_test, O=Unknown, L=Unknown, ST=Unknown, C=Unknown")
+                                .addSanDnsName("localhost");
+        // Add SANs for every potential hostname Sidecar will serve
+        for (int i = 1; i <= 20; i++)
+        {
+            serverKeyStoreBuilder.addSanDnsName("localhost" + i);
+        }
+
+        CertificateBundle serverKeyStore = serverKeyStoreBuilder.buildIssuedBy(certificateAuthority);
+        serverKeyStorePath = serverKeyStore.toTempKeyStorePath(secretsPath, EMPTY_PASSWORD, EMPTY_PASSWORD);
+        CertificateBundle clientKeyStore = new CertificateBuilder().subject("CN=Apache Cassandra, OU=mtls_test, O=Unknown, L=Unknown, ST=Unknown, C=Unknown")
+                                                                   .addSanDnsName("localhost")
+                                                                   .buildIssuedBy(certificateAuthority);
+        clientKeyStorePath = clientKeyStore.toTempKeyStorePath(secretsPath, EMPTY_PASSWORD, EMPTY_PASSWORD);
+    }
+
+    public boolean isEnabled()
+    {
+        return enableMtlsForTesting;
+    }
+
+    public String trustStorePath()
+    {
+        return truststorePath.toString();
+    }
+
+    public String trustStorePassword()
+    {
+        return EMPTY_PASSWORD_STRING;
+    }
+
+    public String trustStoreType()
+    {
+        return "PKCS12";
+    }
+
+    public String serverKeyStorePath()
+    {
+        return serverKeyStorePath.toString();
+    }
+
+    public String serverKeyStorePassword()
+    {
+        return EMPTY_PASSWORD_STRING;
+    }
+
+    public String serverKeyStoreType()
+    {
+        return "PKCS12";
+    }
+
+    public Map<String, String> mtlOptionMap()
+    {
+        if (!isEnabled())
+        {
+            return Collections.emptyMap();
+        }
+
+        LOGGER.info("Test mTLS certificate is enabled. "
+                    + "Will use test keystore as truststore so the client will 
trust the server");
+        Map<String, String> optionMap = new HashMap<>();
+        optionMap.put("truststore_path", trustStorePath());
+        optionMap.put("truststore_password", EMPTY_PASSWORD_STRING);
+        optionMap.put("truststore_type", trustStoreType());
+        optionMap.put("keystore_path", clientKeyStorePath.toString());
+        optionMap.put("keystore_password", EMPTY_PASSWORD_STRING);
+        optionMap.put("keystore_type", "PKCS12");
+        return optionMap;
+    }
+}
diff --git a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java
index 60299ce..e216018 100644
--- a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java
+++ b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java
@@ -39,6 +39,7 @@ import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.TestInstance;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.junit.jupiter.api.io.TempDir;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -72,10 +73,14 @@ import org.apache.cassandra.sidecar.common.dns.DnsResolver;
 import org.apache.cassandra.sidecar.common.utils.DriverUtils;
 import org.apache.cassandra.sidecar.common.utils.SidecarVersionProvider;
 import org.apache.cassandra.sidecar.config.JmxConfiguration;
+import org.apache.cassandra.sidecar.config.KeyStoreConfiguration;
 import org.apache.cassandra.sidecar.config.ServiceConfiguration;
 import org.apache.cassandra.sidecar.config.SidecarConfiguration;
+import org.apache.cassandra.sidecar.config.SslConfiguration;
+import org.apache.cassandra.sidecar.config.yaml.KeyStoreConfigurationImpl;
 import org.apache.cassandra.sidecar.config.yaml.ServiceConfigurationImpl;
 import org.apache.cassandra.sidecar.config.yaml.SidecarConfigurationImpl;
+import org.apache.cassandra.sidecar.config.yaml.SslConfigurationImpl;
 import org.apache.cassandra.sidecar.exceptions.ThrowableUtils;
 import org.apache.cassandra.sidecar.server.MainModule;
 import org.apache.cassandra.sidecar.server.Server;
@@ -87,6 +92,7 @@ import org.apache.cassandra.testing.TestUtils;
 import org.apache.cassandra.testing.TestVersion;
 import org.apache.cassandra.testing.TestVersionSupplier;
 
+import static org.apache.cassandra.sidecar.testing.MtlsTestHelper.CASSANDRA_INTEGRATION_TEST_ENABLE_MTLS;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /**
@@ -131,12 +137,16 @@ public abstract class SharedClusterIntegrationTestBase
     protected final Logger logger = LoggerFactory.getLogger(SharedClusterIntegrationTestBase.class);
     private static final int MAX_CLUSTER_PROVISION_RETRIES = 5;
 
+    @TempDir
+    static Path secretsPath;
+
     protected Vertx vertx;
     protected DnsResolver dnsResolver;
     protected IClusterExtension<? extends IInstance> cluster;
     protected Server server;
     protected Injector injector;
     protected TestVersion testVersion;
+    protected MtlsTestHelper mtlsTestHelper;
     private IsolatedDTestClassLoaderWrapper classLoaderWrapper;
 
     static
@@ -146,7 +156,7 @@ public abstract class SharedClusterIntegrationTestBase
     }
 
     @BeforeAll
-    protected void setup() throws InterruptedException
+    protected void setup() throws Exception
     {
         Optional<TestVersion> maybeTestVersion = TestVersionSupplier.testVersions().findFirst();
         assertThat(maybeTestVersion).isPresent();
@@ -161,6 +171,7 @@ public abstract class SharedClusterIntegrationTestBase
         assertThat(cluster).isNotNull();
         afterClusterProvisioned();
         initializeSchemaForTest();
+        mtlsTestHelper = new MtlsTestHelper(secretsPath);
         startSidecar(cluster);
         beforeTestStart();
     }
@@ -306,7 +317,8 @@ public abstract class SharedClusterIntegrationTestBase
     protected void startSidecar(ICluster<? extends IInstance> cluster) throws InterruptedException
     {
         VertxTestContext context = new VertxTestContext();
-        injector = Guice.createInjector(Modules.override(new MainModule()).with(new IntegrationTestModule(cluster, classLoaderWrapper)));
+        AbstractModule testModule = new IntegrationTestModule(cluster, classLoaderWrapper, mtlsTestHelper);
+        injector = Guice.createInjector(Modules.override(new MainModule()).with(testModule));
         dnsResolver = injector.getInstance(DnsResolver.class);
         vertx = injector.getInstance(Vertx.class);
         server = injector.getInstance(Server.class);
@@ -455,13 +467,18 @@ public abstract class SharedClusterIntegrationTestBase
 
     static class IntegrationTestModule extends AbstractModule
     {
+        private static final Logger LOGGER = LoggerFactory.getLogger(IntegrationTestModule.class);
         private final ICluster<? extends IInstance> cluster;
         private final IsolatedDTestClassLoaderWrapper wrapper;
+        private final MtlsTestHelper mtlsTestHelper;
 
-        IntegrationTestModule(ICluster<? extends IInstance> cluster, IsolatedDTestClassLoaderWrapper wrapper)
+        IntegrationTestModule(ICluster<? extends IInstance> cluster,
+                              IsolatedDTestClassLoaderWrapper wrapper,
+                              MtlsTestHelper mtlsTestHelper)
         {
             this.cluster = cluster;
             this.wrapper = wrapper;
+            this.mtlsTestHelper = mtlsTestHelper;
         }
 
         @Provides
@@ -500,8 +517,39 @@ public abstract class SharedClusterIntegrationTestBase
                                                                .host("0.0.0.0") // binds to all interfaces, potential security issue if left running for long
                                                                .port(0) // let the test find an available port
                                                                 .build();
+
+
+            SslConfiguration sslConfiguration = null;
+            if (mtlsTestHelper.isEnabled())
+            {
+                LOGGER.info("Enabling test mTLS certificate/keystore.");
+
+                KeyStoreConfiguration truststoreConfiguration =
+                new KeyStoreConfigurationImpl(mtlsTestHelper.trustStorePath(),
+                                              mtlsTestHelper.trustStorePassword(),
+                                              mtlsTestHelper.trustStoreType(),
+                                              -1);
+
+                KeyStoreConfiguration keyStoreConfiguration =
+                new KeyStoreConfigurationImpl(mtlsTestHelper.serverKeyStorePath(),
+                                              mtlsTestHelper.serverKeyStorePassword(),
+                                              mtlsTestHelper.serverKeyStoreType(),
+                                              -1);
+
+                sslConfiguration = SslConfigurationImpl.builder()
+                                                       .enabled(true)
+                                                       .keystore(keyStoreConfiguration)
+                                                       .truststore(truststoreConfiguration)
+                                                       .build();
+            }
+            else
+            {
+                LOGGER.info("Not enabling mTLS for testing purposes. Set '{}' to 'true' if you would " +
+                            "like mTLS enabled.", CASSANDRA_INTEGRATION_TEST_ENABLE_MTLS);
+            }
             return SidecarConfigurationImpl.builder()
                                            .serviceConfiguration(conf)
+                                           .sslConfiguration(sslConfiguration)
                                            .build();
         }
 
diff --git a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/testing/utils/tls/CertificateBuilder.java b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/testing/utils/tls/CertificateBuilder.java
new file mode 100644
index 0000000..0bc4671
--- /dev/null
+++ b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/testing/utils/tls/CertificateBuilder.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.testing.utils.tls;
+
+import java.io.IOException;
+import java.math.BigInteger;
+import java.security.GeneralSecurityException;
+import java.security.KeyPair;
+import java.security.KeyPairGenerator;
+import java.security.PrivateKey;
+import java.security.PublicKey;
+import java.security.SecureRandom;
+import java.security.cert.X509Certificate;
+import java.security.spec.AlgorithmParameterSpec;
+import java.security.spec.ECGenParameterSpec;
+import java.security.spec.RSAKeyGenParameterSpec;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.List;
+import java.util.Objects;
+import javax.security.auth.x500.X500Principal;
+
+import org.bouncycastle.asn1.x500.X500Name;
+import org.bouncycastle.asn1.x509.BasicConstraints;
+import org.bouncycastle.asn1.x509.Extension;
+import org.bouncycastle.asn1.x509.GeneralName;
+import org.bouncycastle.asn1.x509.GeneralNames;
+import org.bouncycastle.cert.X509CertificateHolder;
+import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
+import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder;
+import org.bouncycastle.operator.ContentSigner;
+import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
+
+/**
+ * A utility class to generate certificates for tests.
+ *
+ * <p>This class is copied from the Apache Cassandra code
+ */
+public class CertificateBuilder
+{
+    private static final GeneralName[] EMPTY_SAN = {};
+    private static final SecureRandom SECURE_RANDOM = new SecureRandom();
+
+    private boolean isCertificateAuthority;
+    private String alias;
+    private X500Name subject;
+    private SecureRandom random;
+    private Date notBefore = Date.from(Instant.now().minus(1, ChronoUnit.DAYS));
+    private Date notAfter = Date.from(Instant.now().plus(1, ChronoUnit.DAYS));
+    private String algorithm;
+    private AlgorithmParameterSpec algorithmParameterSpec;
+    private String signatureAlgorithm;
+    private BigInteger serial;
+    private final List<GeneralName> subjectAlternativeNames = new ArrayList<>();
+
+    public CertificateBuilder()
+    {
+        ecp256Algorithm();
+    }
+
+    public CertificateBuilder isCertificateAuthority(boolean isCertificateAuthority)
+    {
+        this.isCertificateAuthority = isCertificateAuthority;
+        return this;
+    }
+
+    public CertificateBuilder subject(String subject)
+    {
+        this.subject = new X500Name(Objects.requireNonNull(subject));
+        return this;
+    }
+
+    public CertificateBuilder notBefore(Instant notBefore)
+    {
+        return notBefore(Date.from(Objects.requireNonNull(notBefore)));
+    }
+
+    private CertificateBuilder notBefore(Date notBefore)
+    {
+        this.notBefore = Objects.requireNonNull(notBefore);
+        return this;
+    }
+
+    public CertificateBuilder notAfter(Instant notAfter)
+    {
+        return notAfter(Date.from(Objects.requireNonNull(notAfter)));
+    }
+
+    private CertificateBuilder notAfter(Date notAfter)
+    {
+        this.notAfter = Objects.requireNonNull(notAfter);
+        return this;
+    }
+
+    public CertificateBuilder addSanUriName(String uri)
+    {
+        subjectAlternativeNames.add(new GeneralName(GeneralName.uniformResourceIdentifier, uri));
+        return this;
+    }
+
+    public CertificateBuilder addSanDnsName(String dnsName)
+    {
+        subjectAlternativeNames.add(new GeneralName(GeneralName.dNSName, dnsName));
+        return this;
+    }
+
+    public CertificateBuilder secureRandom(SecureRandom secureRandom)
+    {
+        this.random = Objects.requireNonNull(secureRandom);
+        return this;
+    }
+
+    public CertificateBuilder alias(String alias)
+    {
+        this.alias = Objects.requireNonNull(alias);
+        return this;
+    }
+
+    public CertificateBuilder serial(BigInteger serial)
+    {
+        this.serial = serial;
+        return this;
+    }
+
+    public CertificateBuilder ecp256Algorithm()
+    {
+        this.algorithm = "EC";
+        this.algorithmParameterSpec = new ECGenParameterSpec("secp256r1");
+        this.signatureAlgorithm = "SHA256WITHECDSA";
+        return this;
+    }
+
+    public CertificateBuilder rsa2048Algorithm()
+    {
+        this.algorithm = "RSA";
+        this.algorithmParameterSpec = new RSAKeyGenParameterSpec(2048, RSAKeyGenParameterSpec.F4);
+        this.signatureAlgorithm = "SHA256WITHRSA";
+        return this;
+    }
+
+    public CertificateBundle buildSelfSigned() throws Exception
+    {
+        KeyPair keyPair = generateKeyPair();
+
+        JcaX509v3CertificateBuilder builder = createCertBuilder(subject, subject, keyPair);
+        addExtensions(builder);
+
+        ContentSigner signer = new JcaContentSignerBuilder(signatureAlgorithm).build(keyPair.getPrivate());
+        X509CertificateHolder holder = builder.build(signer);
+        X509Certificate root = new JcaX509CertificateConverter().getCertificate(holder);
+        return new CertificateBundle(signatureAlgorithm, new X509Certificate[]{root}, root, keyPair, alias);
+    }
+
+    public CertificateBundle buildIssuedBy(CertificateBundle issuer) throws Exception
+    {
+        String issuerSignAlgorithm = issuer.signatureAlgorithm();
+        return buildIssuedBy(issuer, issuerSignAlgorithm);
+    }
+
+    public CertificateBundle buildIssuedBy(CertificateBundle issuer, String issuerSignAlgorithm) throws Exception
+    {
+        KeyPair keyPair = generateKeyPair();
+
+        X500Principal issuerPrincipal = issuer.certificate().getSubjectX500Principal();
+        X500Name issuerName = X500Name.getInstance(issuerPrincipal.getEncoded());
+        JcaX509v3CertificateBuilder builder = createCertBuilder(issuerName, subject, keyPair);
+
+        addExtensions(builder);
+
+        PrivateKey issuerPrivateKey = issuer.keyPair().getPrivate();
+        if (issuerPrivateKey == null)
+        {
+            throw new IllegalArgumentException("Cannot sign certificate with issuer that does not have a private key.");
+        }
+        ContentSigner signer = new JcaContentSignerBuilder(issuerSignAlgorithm).build(issuerPrivateKey);
+        X509CertificateHolder holder = builder.build(signer);
+        X509Certificate cert = new JcaX509CertificateConverter().getCertificate(holder);
+        X509Certificate[] issuerPath = issuer.certificatePath();
+        X509Certificate[] path = new X509Certificate[issuerPath.length + 1];
+        path[0] = cert;
+        System.arraycopy(issuerPath, 0, path, 1, issuerPath.length);
+        return new CertificateBundle(signatureAlgorithm, path, issuer.rootCertificate(), keyPair, alias);
+    }
+
+    private SecureRandom secureRandom()
+    {
+        return random != null ? random : SECURE_RANDOM;
+    }
+
+    private KeyPair generateKeyPair() throws GeneralSecurityException
+    {
+        KeyPairGenerator keyGen = KeyPairGenerator.getInstance(algorithm);
+        keyGen.initialize(algorithmParameterSpec, secureRandom());
+        return keyGen.generateKeyPair();
+    }
+
+    private JcaX509v3CertificateBuilder createCertBuilder(X500Name issuer, X500Name subject, KeyPair keyPair)
+    {
+        BigInteger serial = this.serial != null ? this.serial : new BigInteger(159, secureRandom());
+        PublicKey pubKey = keyPair.getPublic();
+        return new JcaX509v3CertificateBuilder(issuer, serial, notBefore, notAfter, subject, pubKey);
+    }
+
+    private void addExtensions(JcaX509v3CertificateBuilder builder) throws IOException
+    {
+        if (isCertificateAuthority)
+        {
+            builder.addExtension(Extension.basicConstraints, true, new BasicConstraints(true));
+        }
+
+        boolean criticality = false;
+        if (!subjectAlternativeNames.isEmpty())
+        {
+            builder.addExtension(Extension.subjectAlternativeName, criticality,
+                                 new GeneralNames(subjectAlternativeNames.toArray(EMPTY_SAN)));
+        }
+    }
+}
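
As an aside (not part of the commit), a minimal sketch of how a test might use the CertificateBuilder/CertificateBundle utilities above to provision a throwaway CA and server certificate; the subjects, aliases, and directory name here are assumptions made for the example.

    import java.nio.file.Files;
    import java.nio.file.Path;

    import org.apache.cassandra.testing.utils.tls.CertificateBuilder;
    import org.apache.cassandra.testing.utils.tls.CertificateBundle;

    public class CertificateBuilderSketch
    {
        public static void main(String[] args) throws Exception
        {
            Path secrets = Files.createTempDirectory("secrets");

            // Self-signed root CA used to sign the server certificate
            CertificateBundle ca = new CertificateBuilder()
                                   .subject("CN=Test Root CA, O=Example, C=Unknown")
                                   .alias("testrootca")
                                   .isCertificateAuthority(true)
                                   .buildSelfSigned();

            // Leaf certificate for the server, issued by the CA above
            CertificateBundle server = new CertificateBuilder()
                                       .subject("CN=localhost, O=Example, C=Unknown")
                                       .addSanDnsName("localhost")
                                       .buildIssuedBy(ca);

            // Write both out as password-less PKCS#12 keystores for a test
            Path truststore = ca.toTempKeyStorePath(secrets, new char[0], new char[0]);
            Path keystore = server.toTempKeyStorePath(secrets, new char[0], new char[0]);
            System.out.println("truststore: " + truststore + ", keystore: " + keystore);
        }
    }
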
diff --git a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/testing/utils/tls/CertificateBundle.java b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/testing/utils/tls/CertificateBundle.java
new file mode 100644
index 0000000..34a6f02
--- /dev/null
+++ b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/testing/utils/tls/CertificateBundle.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.testing.utils.tls;
+
+import java.io.OutputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.security.KeyPair;
+import java.security.KeyStore;
+import java.security.KeyStoreException;
+import java.security.cert.X509Certificate;
+import java.util.Objects;
+
+/**
+ * This class is copied from the Apache Cassandra code
+ */
+public class CertificateBundle
+{
+    private final String signatureAlgorithm;
+    private final X509Certificate[] chain;
+    private final X509Certificate root;
+    private final KeyPair keyPair;
+    private final String alias;
+
+    public CertificateBundle(String signatureAlgorithm, X509Certificate[] chain,
+                             X509Certificate root, KeyPair keyPair, String alias)
+    {
+        this.signatureAlgorithm = Objects.requireNonNull(signatureAlgorithm);
+        this.chain = chain;
+        this.root = root;
+        this.keyPair = keyPair;
+        this.alias = alias != null ? alias : "1";
+    }
+
+    public KeyStore toKeyStore(char[] keyEntryPassword) throws KeyStoreException
+    {
+        KeyStore keyStore;
+        try
+        {
+            keyStore = KeyStore.getInstance("PKCS12");
+            keyStore.load(null, null);
+        }
+        catch (Exception e)
+        {
+            throw new RuntimeException("Failed to initialize PKCS#12 KeyStore.", e);
+        }
+        keyStore.setCertificateEntry("1", root);
+        if (!isCertificateAuthority())
+        {
+            keyStore.setKeyEntry(alias, keyPair.getPrivate(), keyEntryPassword, chain);
+        }
+        return keyStore;
+    }
+
+    public Path toTempKeyStorePath(Path baseDir, char[] pkcs12Password, char[] keyEntryPassword) throws Exception
+    {
+        KeyStore keyStore = toKeyStore(keyEntryPassword);
+        Path tempFile = Files.createTempFile(baseDir, "ks", ".p12");
+        try (OutputStream out = Files.newOutputStream(tempFile, StandardOpenOption.WRITE))
+        {
+            keyStore.store(out, pkcs12Password);
+        }
+        return tempFile;
+    }
+
+    public boolean isCertificateAuthority()
+    {
+        return chain[0].getBasicConstraints() != -1;
+    }
+
+    public X509Certificate certificate()
+    {
+        return chain[0];
+    }
+
+    public KeyPair keyPair()
+    {
+        return keyPair;
+    }
+
+    public X509Certificate[] certificatePath()
+    {
+        return chain.clone();
+    }
+
+    public X509Certificate rootCertificate()
+    {
+        return root;
+    }
+
+    public String signatureAlgorithm()
+    {
+        return signatureAlgorithm;
+    }
+}
diff --git a/cassandra-analytics-integration-tests/build.gradle b/cassandra-analytics-integration-tests/build.gradle
index 3230831..a6af9f9 100644
--- a/cassandra-analytics-integration-tests/build.gradle
+++ b/cassandra-analytics-integration-tests/build.gradle
@@ -44,6 +44,9 @@ println("Using ${integrationMaxHeapSize} maxHeapSize")
 def integrationMaxParallelForks = (System.getenv("INTEGRATION_MAX_PARALLEL_FORKS") ?: "4") as int
 println("Using ${integrationMaxParallelForks} maxParallelForks")
 
+def integrationEnableMtls = (System.getenv("INTEGRATION_MTLS_ENABLED") ?: "true") as boolean
+println("Using mTLS for tests? ${integrationEnableMtls}")
+
 configurations {
     // remove netty-all dependency coming from spark
     all*.exclude(group: 'io.netty', module: 'netty-all')
@@ -77,6 +80,7 @@ test {
             System.getProperty("cassandra.sidecar.versions_to_test", "4.1")
     systemProperty "SKIP_STARTUP_VALIDATIONS", "true"
     systemProperty "logback.configurationFile", 
"src/test/resources/logback-test.xml"
+    systemProperty "cassandra.integration.sidecar.test.enable_mtls", 
integrationEnableMtls
     minHeapSize = '1g'
     maxHeapSize = integrationMaxHeapSize
     maxParallelForks = integrationMaxParallelForks
diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BlockedInstancesTest.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BlockedInstancesTest.java
index 83e1ad7..82b5d9a 100644
--- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BlockedInstancesTest.java
+++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BlockedInstancesTest.java
@@ -20,6 +20,7 @@ package org.apache.cassandra.analytics;
 
 import java.lang.reflect.Method;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
@@ -66,9 +67,11 @@ public class BlockedInstancesTest extends ResiliencyTestBase
     {
         TestConsistencyLevel cl = TestConsistencyLevel.of(QUORUM, QUORUM);
         QualifiedName table = new QualifiedName(TEST_KEYSPACE, tableName(testInfo));
-        bulkWriterDataFrameWriter(df, table).option(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name())
-                                            .option(WriterOptions.BLOCKED_CASSANDRA_INSTANCES.name(), "127.0.0.2")
-                                            .save();
+        Map<String, String> additionalOptions = new HashMap<>();
+        additionalOptions.put(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name());
+        additionalOptions.put(WriterOptions.BLOCKED_CASSANDRA_INSTANCES.name(), "127.0.0.2");
+
+        bulkWriterDataFrameWriter(df, table, additionalOptions).save();
         expectedInstanceData.entrySet()
                             .stream()
                             .filter(e -> e.getKey().broadcastAddress().getAddress().getHostAddress().equals("127.0.0.2"))
@@ -82,10 +85,12 @@ public class BlockedInstancesTest extends ResiliencyTestBase
     {
         TestConsistencyLevel cl = TestConsistencyLevel.of(ONE, ALL);
         QualifiedName table = new QualifiedName(TEST_KEYSPACE, tableName(testInfo));
-        Throwable thrown = catchThrowable(() ->
-                                          bulkWriterDataFrameWriter(df, table).option(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name())
-                                                                              .option(WriterOptions.BLOCKED_CASSANDRA_INSTANCES.name(), "127.0.0.2")
-                                                                              .save());
+        Throwable thrown = catchThrowable(() -> {
+            Map<String, String> additionalOptions = new HashMap<>();
+            additionalOptions.put(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name());
+            additionalOptions.put(WriterOptions.BLOCKED_CASSANDRA_INSTANCES.name(), "127.0.0.2");
+            bulkWriterDataFrameWriter(df, table, additionalOptions).save();
+        });
         validateFailedJob(table, cl, thrown);
     }
 
@@ -94,10 +99,12 @@ public class BlockedInstancesTest extends ResiliencyTestBase
     {
         TestConsistencyLevel cl = TestConsistencyLevel.of(QUORUM, QUORUM);
         QualifiedName table = new QualifiedName(TEST_KEYSPACE, tableName(testInfo));
-        Throwable thrown = catchThrowable(() ->
-                                          bulkWriterDataFrameWriter(df, table).option(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name())
-                                                                              .option(WriterOptions.BLOCKED_CASSANDRA_INSTANCES.name(), "127.0.0.2,127.0.0.3")
-                                                                              .save());
+        Throwable thrown = catchThrowable(() -> {
+            Map<String, String> additionalOptions = new HashMap<>();
+            additionalOptions.put(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name());
+            additionalOptions.put(WriterOptions.BLOCKED_CASSANDRA_INSTANCES.name(), "127.0.0.2,127.0.0.3");
+            bulkWriterDataFrameWriter(df, table, additionalOptions).save();
+        });
         validateFailedJob(table, cl, thrown);
     }
 
diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/DataGenerationUtils.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/DataGenerationUtils.java
index 3c1897d..8ed742a 100644
--- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/DataGenerationUtils.java
+++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/DataGenerationUtils.java
@@ -66,7 +66,7 @@ public final class DataGenerationUtils
      */
     public static Dataset<Row> generateCourseData(SparkSession spark, int rowCount)
     {
-        return generateCourseData(spark, rowCount, false);
+        return generateCourseData(spark, rowCount, false, null, null);
     }
 
     /**
@@ -84,6 +84,28 @@ public final class DataGenerationUtils
      * @return a {@link Dataset} with generated data
      */
     public static Dataset<Row> generateCourseData(SparkSession spark, int rowCount, boolean udfData)
+    {
+        return generateCourseData(spark, rowCount, udfData, null, null);
+    }
+
+    /**
+     * Generates course data with schema
+     *
+     * <pre>
+     *     id integer,
+     *     course string,
+     *     marks integer
+     * </pre>
+     *
+     * @param spark     the spark session to use
+     * @param rowCount  the number of records to generate
+     * @param udfData   if a field representing a User Defined Type should be added
+     * @param ttl       (optional) a TTL value for the data frame
+     * @param timestamp (optional) a timestamp value for the data frame
+     * @return a {@link Dataset} with generated data
+     */
+    public static Dataset<Row> generateCourseData(SparkSession spark, int rowCount, boolean udfData,
+                                                  Integer ttl, Long timestamp)
     {
         SQLContext sql = spark.sqlContext();
         StructType schema = new StructType()
@@ -98,15 +120,33 @@ public final class DataGenerationUtils
             schema = schema.add("User_Defined_Type", udfType);
         }
 
+        if (ttl != null)
+        {
+            schema = schema.add("ttl", IntegerType, false);
+        }
+
+        if (timestamp != null)
+        {
+            schema = schema.add("timestamp", LongType, false);
+        }
+
         List<Row> rows = IntStream.range(0, rowCount)
                                   .mapToObj(recordNum -> {
                                       String course = "course" + recordNum;
-                                      if (!udfData)
+                                      List<Object> values = new ArrayList<>(Arrays.asList(recordNum, course, recordNum));
+                                      if (udfData)
+                                      {
+                                          values.add(RowFactory.create(recordNum, recordNum));
+                                      }
+                                      if (ttl != null)
+                                      {
+                                          values.add(ttl);
+                                      }
+                                      if (timestamp != null)
                                       {
-                                          return RowFactory.create(recordNum, course, recordNum);
+                                          values.add(timestamp);
                                       }
-                                      return RowFactory.create(recordNum, course, recordNum,
-                                                               RowFactory.create(recordNum, recordNum));
+                                      return RowFactory.create(values.toArray());
                                   }).collect(Collectors.toList());
         return sql.createDataFrame(rows, schema);
     }
diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/IntegrationTestJob.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/IntegrationTestJob.java
deleted file mode 100644
index acf4200..0000000
--- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/IntegrationTestJob.java
+++ /dev/null
@@ -1,379 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.cassandra.analytics;
-
-import java.io.Serializable;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.UUID;
-import java.util.function.Consumer;
-import java.util.function.Function;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import java.util.stream.LongStream;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import org.apache.cassandra.spark.KryoRegister;
-import org.apache.cassandra.spark.bulkwriter.BulkSparkConf;
-import org.apache.spark.SparkConf;
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.sql.DataFrameWriter;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.types.StructType;
-import org.jetbrains.annotations.NotNull;
-
-/**
- * Spark job for use in integration tests. It contains a framework for 
generating data,
- * writing using the bulk writer, and then reading that data back using the 
reader.
- * It has extension points for the actual row and schema generation.
- */
-public final class IntegrationTestJob implements Serializable
-{
-    private static final Logger LOGGER = 
LoggerFactory.getLogger(IntegrationTestJob.class);
-    @NotNull
-    private final RowGenerator rowGenerator;
-    // Transient because we only need it on the driver.
-    @NotNull
-    private final transient Function<Dataset<Row>, Dataset<Row>> 
postWriteDatasetModifier;
-    private final int rowCount;
-    private final int sidecarPort;
-    private final StructType writeSchema;
-    @NotNull
-    private final Map<String, String> extraWriterOptions;
-    private final boolean shouldWrite;
-    private final boolean shouldRead;
-    private final String table;
-    private String keyspace;
-
-    //CHECKSTYLE IGNORE: This is only called from the Builder
-    private IntegrationTestJob(@NotNull RowGenerator rowGenerator,
-                               @NotNull StructType writeSchema,
-                               @NotNull Function<Dataset<Row>, Dataset<Row>> 
postWriteDatasetModifier,
-                               @NotNull int rowCount,
-                               @NotNull int sidecarPort,
-                               @NotNull Map<String, String> extraWriterOptions,
-                               boolean shouldWrite,
-                               boolean shouldRead, String keyspace, String 
table)
-    {
-        this.rowGenerator = rowGenerator;
-        this.postWriteDatasetModifier = postWriteDatasetModifier;
-        this.rowCount = rowCount;
-        this.sidecarPort = sidecarPort;
-        this.writeSchema = writeSchema;
-        this.extraWriterOptions = extraWriterOptions;
-        this.shouldWrite = shouldWrite;
-        this.shouldRead = shouldRead;
-        this.keyspace = keyspace;
-        this.table = table;
-    }
-
-    public static Builder builder(RowGenerator rowGenerator, StructType 
writeSchema)
-    {
-        return new Builder(rowGenerator, writeSchema);
-    }
-
-    private static void checkSmallDataFrameEquality(Dataset<Row> expected, 
Dataset<Row> actual)
-    {
-        if (actual == null)
-        {
-            throw new NullPointerException("actual dataframe is null");
-        }
-        if (expected == null)
-        {
-            throw new NullPointerException("expected dataframe is null");
-        }
-        // Simulate `actual` having fewer rows, but all match rows in 
`expected`.
-        // The previous implementation would consider these equal
-        // actual = actual.limit(1000);
-        if (!actual.exceptAll(expected).isEmpty() || 
!expected.exceptAll(actual).isEmpty())
-        {
-            throw new IllegalStateException("The content of the dataframes 
differs");
-        }
-    }
-
-    private Dataset<Row> write(SQLContext sql, SparkContext sc)
-    {
-        JavaSparkContext javaSparkContext = 
JavaSparkContext.fromSparkContext(sc);
-        JavaRDD<Row> rows = genDataset(javaSparkContext, 
sc.defaultParallelism());
-        Dataset<Row> df = sql.createDataFrame(rows, writeSchema);
-
-        DataFrameWriter<Row> dfWriter = df.write()
-                                          
.format("org.apache.cassandra.spark.sparksql.CassandraDataSink")
-                                          .option("sidecar_instances", 
"localhost,localhost2,localhost3")
-                                          .option("sidecar_port", sidecarPort)
-                                          .option("keyspace", keyspace)
-                                          .option("table", table)
-                                          .option("local_dc", "datacenter1")
-                                          .option("bulk_writer_cl", 
"LOCAL_QUORUM")
-                                          .option("number_splits", "-1")
-                                          .options(extraWriterOptions)
-                                          .mode("append");
-        dfWriter.save();
-        return df;
-    }
-
-    private Dataset<Row> read(SparkConf sparkConf, SQLContext sql, 
SparkContext sc)
-    {
-        int expectedRowCount = rowCount;
-        int coresPerExecutor = sparkConf.getInt("spark.executor.cores", 1);
-        int numExecutors = 
sparkConf.getInt("spark.dynamicAllocation.maxExecutors", 
sparkConf.getInt("spark.executor.instances", 1));
-        int numCores = coresPerExecutor * numExecutors;
-
-        Dataset<Row> df = 
sql.read().format("org.apache.cassandra.spark.sparksql.CassandraDataSource")
-                             .option("sidecar_instances", 
"localhost,localhost2,localhost3")
-                             .option("sidecar_port", sidecarPort)
-                             .option("keyspace", "spark_test")
-                             .option("table", "test")
-                             .option("DC", "datacenter1")
-                             .option("snapshotName", 
UUID.randomUUID().toString())
-                             .option("createSnapshot", "true")
-                             // Shutdown hooks are called after the job ends, 
and in the case of integration tests
-                             // the sidecar is already shut down before this. 
Since the cluster will be torn
-                             // down anyway, the integration job skips 
clearing snapshots.
-                             .option("clearSnapshot", "false")
-                             .option("defaultParallelism", 
sc.defaultParallelism())
-                             .option("numCores", numCores)
-                             .option("sizing", "default")
-                             .load();
-
-        long count = df.count();
-        LOGGER.info("Found {} records", count);
-
-        if (count != expectedRowCount)
-        {
-            throw new RuntimeException(String.format("Expected %d records but 
found %d records",
-                                                     expectedRowCount,
-                                                     count));
-        }
-        return df;
-    }
-
-    private JavaRDD<Row> genDataset(JavaSparkContext sc, int parallelism)
-    {
-        long recordsPerPartition = rowCount / parallelism;
-        long remainder = rowCount - (recordsPerPartition * parallelism);
-        List<Integer> seq = IntStream.range(0, 
parallelism).boxed().collect(Collectors.toList());
-        JavaRDD<Row> dataset = sc.parallelize(seq, 
parallelism).mapPartitionsWithIndex(
-        (Function2<Integer, Iterator<Integer>, Iterator<Row>>) (index, 
integerIterator) -> {
-            long firstRecordNumber = index * recordsPerPartition;
-            long recordsToGenerate = index.equals(parallelism) ? remainder : 
recordsPerPartition;
-            java.util.Iterator<Row> rows = LongStream.range(0, 
recordsToGenerate).mapToObj(offset -> {
-                long recordNumber = firstRecordNumber + offset;
-                return rowGenerator.rowFor(recordNumber);
-            }).iterator();
-            return rows;
-        }, false);
-        return dataset;
-    }
-
-    public void run()
-    {
-        LOGGER.info("Starting Spark job with args={}", this);
-
-        SparkConf sparkConf = new SparkConf().setAppName("Sample Spark 
Cassandra Bulk Reader Job")
-                                             .set("spark.master", "local[8]");
-
-        // Add SBW-specific settings
-        // TODO: Simplify setting up spark conf
-        BulkSparkConf.setupSparkConf(sparkConf, true);
-        KryoRegister.setup(sparkConf);
-
-        SparkSession spark = SparkSession
-                             .builder()
-                             .config(sparkConf)
-                             .getOrCreate();
-        SparkContext sc = spark.sparkContext();
-        SQLContext sql = spark.sqlContext();
-        LOGGER.info("Spark Conf: " + sparkConf.toDebugString());
-
-        try
-        {
-            Dataset<Row> written = null;
-            Dataset<Row> read = null;
-            if (shouldWrite)
-            {
-                written = write(sql, sc);
-                written = postWriteDatasetModifier.apply(written);
-            }
-
-            if (shouldRead)
-            {
-                read = read(sparkConf, sql, sc);
-            }
-            if (read != null && written != null)
-            {
-                checkSmallDataFrameEquality(written, read);
-            }
-            LOGGER.info("Finished Spark job, shutting down...");
-            sc.stop();
-        }
-        catch (Throwable throwable)
-        {
-            LOGGER.error("Unexpected exception executing Spark job", 
throwable);
-            try
-            {
-                sc.stop();
-            }
-            catch (Throwable ignored)
-            {
-            }
-            throw throwable; // rethrow so the exception bubbles up to test 
usages.
-        }
-    }
-
-    @Override
-    public String toString()
-    {
-        return "IntegrationTestJob: { \n " +
-               "    rowCount:%d,\n" +
-               "    parallelism:%d,\n" +
-               "    ttl:%d,\n" +
-               "    timestamp:%d,\n" +
-               "    sidecarPort:%d\n" +
-               "}";
-    }
-
-    @FunctionalInterface
-    public interface RowGenerator extends Serializable
-    {
-        Row rowFor(long recordNumber);
-    }
-
-    public static class Builder
-    {
-        private final RowGenerator rowGenerator;
-        private final StructType writeSchema;
-        private int rowCount = 10_000;
-        private int sidecarPort = 9043;
-        private Map<String, String> extraWriterOptions;
-        private Function<Dataset<Row>, Dataset<Row>> postWriteDatasetModifier 
= Function.identity();
-        private boolean shouldWrite = true;
-        private boolean shouldRead = true;
-        private String keyspace = "spark_test";
-        private String table = "test";
-
-        Builder(@NotNull RowGenerator rowGenerator, StructType writeSchema)
-        {
-            this.rowGenerator = rowGenerator;
-            this.writeSchema = writeSchema;
-        }
-
-        public Builder withKeyspace(String keyspace)
-        {
-            return update(builder -> builder.keyspace = keyspace);
-        }
-
-        public Builder withTable(String table)
-        {
-            return update(builder -> builder.table = table);
-        }
-
-        public Builder withRowCount(int rowCount)
-        {
-            return update(builder -> builder.rowCount = rowCount);
-        }
-
-        public Builder withSidecarPort(int sidecarPort)
-        {
-            return update(builder -> builder.sidecarPort = sidecarPort);
-        }
-
-        private Builder update(Consumer<Builder> update)
-        {
-            update.accept(this);
-            return this;
-        }
-
-        public Builder withExtraWriterOptions(Map<String, String> 
writerOptions)
-        {
-            return update(builder -> builder.extraWriterOptions = 
writerOptions);
-        }
-
-        public Builder withPostWriteDatasetModifier(Function<Dataset<Row>, 
Dataset<Row>> dataSetModifier)
-        {
-            return update(builder -> builder.postWriteDatasetModifier = 
dataSetModifier);
-        }
-
-
-        public Builder shouldWrite(boolean shouldWrite)
-        {
-            return update(builder -> builder.shouldWrite = shouldWrite);
-        }
-
-        public Builder shouldRead(boolean shouldRead)
-        {
-            return update(builder -> builder.shouldRead = shouldRead);
-        }
-
-
-        public IntegrationTestJob build()
-        {
-            return new IntegrationTestJob(this.rowGenerator,
-                                          this.writeSchema,
-                                          this.postWriteDatasetModifier,
-                                          this.rowCount,
-                                          this.sidecarPort,
-                                          this.extraWriterOptions,
-                                          this.shouldWrite,
-                                          this.shouldRead,
-                                          this.keyspace,
-                                          this.table);
-        }
-
-        public void run()
-        {
-            build().run();
-        }
-    }
-
-    /**
-     * An example of a postWriteDatasetModifier.
-     * This function will remove `ttl` and `timestamp` columns from the 
dataframe after write, as
-     * the read part of the integration test job doesn't read ttl and 
timestamp columns, we need to remove them
-     * from the Dataset after it's saved so the final comparison works.
-     * @param addedTTLColumn if the "ttl" column was added to the dataframe.
-     * @param addedTimestampColumn if the "timestamp" column was added to the 
dataframe.
-     * @return the modified dataset
-     */
-    @SuppressWarnings("unused")
-    public static Function<Dataset<Row>, Dataset<Row>> 
ttlRemovalModifier(boolean addedTTLColumn, boolean addedTimestampColumn)
-    {
-        return (Dataset<Row> df) -> {
-            if (addedTTLColumn)
-            {
-                df = df.drop("ttl");
-            }
-            if (addedTimestampColumn)
-            {
-                df = df.drop("timestamp");
-            }
-            return df;
-        };
-    }
-}
diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java
index d92a737..daf7469 100644
--- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java
+++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java
@@ -19,8 +19,10 @@
 
 package org.apache.cassandra.analytics;
 
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -68,7 +70,7 @@ public abstract class SharedClusterSparkIntegrationTestBase extends SharedCluste
     protected void beforeTestStart()
     {
         super.beforeTestStart();
-        sparkTestUtils.initialize(cluster.delegate(), dnsResolver, server.actualPort());
+        sparkTestUtils.initialize(cluster.delegate(), dnsResolver, server.actualPort(), mtlsTestHelper);
     }
 
     @Override
@@ -101,7 +103,22 @@ public abstract class SharedClusterSparkIntegrationTestBase extends SharedCluste
      */
     protected DataFrameWriter<Row> bulkWriterDataFrameWriter(Dataset<Row> df, QualifiedName tableName)
     {
-        return sparkTestUtils.defaultBulkWriterDataFrameWriter(df, tableName);
+        return sparkTestUtils.defaultBulkWriterDataFrameWriter(df, tableName, Collections.emptyMap());
+    }
+
+    /**
+     * A preconfigured {@link DataFrameWriter} with pre-populated required options that can be overridden
+     * with additional options for every specific test.
+     *
+     * @param df                the source dataframe to write
+     * @param tableName         the qualified name for the Cassandra table
+     * @param additionalOptions additional options for the data frame
+     * @return a {@link DataFrameWriter} for Cassandra bulk writes
+     */
+    protected DataFrameWriter<Row> bulkWriterDataFrameWriter(Dataset<Row> df, QualifiedName tableName,
+                                                             Map<String, String> additionalOptions)
+    {
+        return sparkTestUtils.defaultBulkWriterDataFrameWriter(df, tableName, additionalOptions);
     }
 
     protected SparkConf getOrCreateSparkConf()
@@ -135,6 +152,25 @@ public abstract class SharedClusterSparkIntegrationTestBase extends SharedCluste
         return bridge;
     }
 
+    public void checkSmallDataFrameEquality(Dataset<Row> expected, Dataset<Row> actual)
+    {
+        if (actual == null)
+        {
+            throw new NullPointerException("actual dataframe is null");
+        }
+        if (expected == null)
+        {
+            throw new NullPointerException("expected dataframe is null");
+        }
+        // Simulate `actual` having fewer rows, but all match rows in `expected`.
+        // The previous implementation would consider these equal
+        // actual = actual.limit(1000);
+        if (!actual.exceptAll(expected).isEmpty() || !expected.exceptAll(actual).isEmpty())
+        {
+            throw new IllegalStateException("The content of the dataframes differs");
+        }
+    }
+
     public void validateWritesWithDriverResultSet(List<Row> sourceData, ResultSet queriedData,
                                                   Function<com.datastax.driver.core.Row, String> rowFormatter)
     {
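
A note on the checkSmallDataFrameEquality helper added above: it diffs the frames in both directions with exceptAll, so an extra or missing row on either side fails the comparison, and duplicate rows are preserved rather than collapsed. The following is a minimal standalone sketch of the same check, using two small illustrative datasets that are not part of the patch.

    import java.util.Arrays;

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    final class DataFrameEqualitySketch
    {
        public static void main(String[] args)
        {
            SparkSession spark = SparkSession.builder()
                                             .master("local[*]")
                                             .appName("exceptAll-equality-sketch")
                                             .getOrCreate();

            Dataset<Row> expected = spark.createDataset(Arrays.asList(1, 2, 3), Encoders.INT()).toDF("id");
            Dataset<Row> actual = spark.createDataset(Arrays.asList(1, 2, 3, 3), Encoders.INT()).toDF("id");

            // exceptAll keeps duplicates, so the extra 3 in `actual` is reported as a difference;
            // a plain except (set difference) would have considered the two frames equal.
            boolean equal = actual.exceptAll(expected).isEmpty() && expected.exceptAll(actual).isEmpty();
            System.out.println("dataframes equal: " + equal);

            spark.stop();
        }
    }
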
diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkBulkWriterSimpleTest.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkBulkWriterSimpleTest.java
index ebb6860..3c7f156 100644
--- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkBulkWriterSimpleTest.java
+++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkBulkWriterSimpleTest.java
@@ -19,9 +19,6 @@
 
 package org.apache.cassandra.analytics;
 
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.function.Function;
@@ -36,61 +33,57 @@ import org.apache.cassandra.spark.bulkwriter.WriterOptions;
 import org.apache.cassandra.testing.ClusterBuilderConfiguration;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.types.StructType;
-import org.jetbrains.annotations.NotNull;
+import org.apache.spark.sql.SparkSession;
 
+import static org.apache.cassandra.testing.TestUtils.CREATE_TEST_TABLE_STATEMENT;
 import static org.apache.cassandra.testing.TestUtils.DC1_RF3;
+import static org.apache.cassandra.testing.TestUtils.ROW_COUNT;
 import static org.apache.cassandra.testing.TestUtils.TEST_KEYSPACE;
-import static org.apache.spark.sql.types.DataTypes.BinaryType;
-import static org.apache.spark.sql.types.DataTypes.IntegerType;
-import static org.apache.spark.sql.types.DataTypes.LongType;
 
 class SparkBulkWriterSimpleTest extends SharedClusterSparkIntegrationTestBase
 {
-    static final String CREATE_TABLE_STATEMENT = "CREATE TABLE %s (id BIGINT PRIMARY KEY, course BLOB, marks BIGINT)";
+    static final QualifiedName QUALIFIED_NAME = new QualifiedName(TEST_KEYSPACE, "test");
 
     @ParameterizedTest
     @MethodSource("options")
     void runSampleJob(Integer ttl, Long timestamp)
     {
         Map<String, String> writerOptions = new HashMap<>();
-
-        boolean addTTLColumn = ttl != null;
-        boolean addTimestampColumn = timestamp != null;
-        if (addTTLColumn)
+        if (ttl != null)
         {
             writerOptions.put(WriterOptions.TTL.name(), "ttl");
         }
-
-        if (addTimestampColumn)
+        if (timestamp != null)
         {
             writerOptions.put(WriterOptions.TIMESTAMP.name(), "timestamp");
         }
 
-        IntegrationTestJob.builder((recordNum) -> generateCourse(recordNum, ttl, timestamp),
-                                   getWriteSchema(addTTLColumn, addTimestampColumn))
-                          .withSidecarPort(server.actualPort())
-                          .withExtraWriterOptions(writerOptions)
-                          .withPostWriteDatasetModifier(writeToReadDfFunc(addTTLColumn, addTimestampColumn))
-                          .run();
-    }
+        SparkSession spark = getOrCreateSparkSession();
 
-    static Stream<Arguments> options()
-    {
-        return Stream.of(
-        Arguments.of(null, null),
-        Arguments.of(1000, null),
-        Arguments.of(null, 1432815430948567L),
-        Arguments.of(1000, 1432815430948567L)
-        );
+        // Generate some data
+        Dataset<Row> dfWrite = DataGenerationUtils.generateCourseData(spark, ROW_COUNT, false, ttl, timestamp);
+
+        // Write the data using Bulk Writer
+        bulkWriterDataFrameWriter(dfWrite, QUALIFIED_NAME, writerOptions).save();
+
+        // Validate using CQL
+        sparkTestUtils.validateWrites(dfWrite.collectAsList(), queryAllData(QUALIFIED_NAME));
+
+        // Remove columns from write DF to perform validations
+        Dataset<Row> written = writeToReadDfFunc(ttl != null, timestamp != null).apply(dfWrite);
+
+        // Read data back using Bulk Reader
+        Dataset<Row> read = bulkReaderDataFrame(QUALIFIED_NAME).load();
+
+        // Validate that written and read dataframes are the same
+        checkSmallDataFrameEquality(written, read);
     }
 
     @Override
     protected void initializeSchemaForTest()
     {
         createTestKeyspace(TEST_KEYSPACE, DC1_RF3);
-        createTestTable(new QualifiedName(TEST_KEYSPACE, "test"), CREATE_TABLE_STATEMENT);
+        createTestTable(QUALIFIED_NAME, CREATE_TEST_TABLE_STATEMENT);
     }
 
     @Override
@@ -100,9 +93,19 @@ class SparkBulkWriterSimpleTest extends SharedClusterSparkIntegrationTestBase
                     .nodesPerDc(3);
     }
 
+    static Stream<Arguments> options()
+    {
+        return Stream.of(
+        Arguments.of(null, null),
+        Arguments.of(1000, null),
+        Arguments.of(null, 1432815430948567L),
+        Arguments.of(1000, 1432815430948567L)
+        );
+    }
+
     // Because the read part of the integration test job doesn't read ttl and timestamp columns, we need to remove them
     // from the Dataset after it's saved.
-    private Function<Dataset<Row>, Dataset<Row>> writeToReadDfFunc(boolean addedTTLColumn, boolean addedTimestampColumn)
+    static Function<Dataset<Row>, Dataset<Row>> writeToReadDfFunc(boolean addedTTLColumn, boolean addedTimestampColumn)
     {
         return (Dataset<Row> df) -> {
             if (addedTTLColumn)
@@ -116,53 +119,4 @@ class SparkBulkWriterSimpleTest extends SharedClusterSparkIntegrationTestBase
             return df;
         };
     }
-
-    @SuppressWarnings("SameParameterValue")
-    private static StructType getWriteSchema(boolean addTTLColumn, boolean addTimestampColumn)
-    {
-        StructType schema = new StructType()
-                            .add("id", LongType, false)
-                            .add("course", BinaryType, false)
-                            .add("marks", LongType, false);
-        if (addTTLColumn)
-        {
-            schema = schema.add("ttl", IntegerType, false);
-        }
-        if (addTimestampColumn)
-        {
-            schema = schema.add("timestamp", LongType, false);
-        }
-        return schema;
-    }
-
-    @NotNull
-    @SuppressWarnings("SameParameterValue")
-    private static Row generateCourse(long recordNumber, Integer ttl, Long timestamp)
-    {
-        String courseNameString = String.valueOf(recordNumber);
-        int courseNameStringLen = courseNameString.length();
-        int courseNameMultiplier = 1000 / courseNameStringLen;
-        byte[] courseName = dupStringAsBytes(courseNameString, courseNameMultiplier);
-        ArrayList<Object> values = new ArrayList<>(Arrays.asList(recordNumber, courseName, recordNumber));
-        if (ttl != null)
-        {
-            values.add(ttl);
-        }
-        if (timestamp != null)
-        {
-            values.add(timestamp);
-        }
-        return RowFactory.create(values.toArray());
-    }
-
-    private static byte[] dupStringAsBytes(String string, Integer times)
-    {
-        byte[] stringBytes = string.getBytes();
-        ByteBuffer buffer = ByteBuffer.allocate(stringBytes.length * times);
-        for (int time = 0; time < times; time++)
-        {
-            buffer.put(stringBytes);
-        }
-        return buffer.array();
-    }
 }
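
For reference, the TTL and timestamp handling in runSampleJob above boils down to mapping WriterOptions entries to the column names that the data generator produces, so the bulk writer can pick the per-row values up from those columns. The sketch below shows only that wiring; the class and method names are illustrative and not part of the patch.

    import java.util.HashMap;
    import java.util.Map;

    import org.apache.cassandra.spark.bulkwriter.WriterOptions;

    final class WriterOptionsSketch
    {
        // Builds the per-test writer options the same way runSampleJob does: when a TTL or
        // timestamp value is supplied, point the corresponding option at the dataframe
        // column of the same name.
        static Map<String, String> ttlAndTimestampOptions(Integer ttl, Long timestamp)
        {
            Map<String, String> writerOptions = new HashMap<>();
            if (ttl != null)
            {
                writerOptions.put(WriterOptions.TTL.name(), "ttl");
            }
            if (timestamp != null)
            {
                writerOptions.put(WriterOptions.TIMESTAMP.name(), "timestamp");
            }
            return writerOptions;
        }
    }
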
diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkTestUtils.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkTestUtils.java
index 46b70a9..a4a4734 100644
--- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkTestUtils.java
+++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SparkTestUtils.java
@@ -22,6 +22,7 @@ package org.apache.cassandra.analytics;
 import java.net.UnknownHostException;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.UUID;
@@ -32,6 +33,7 @@ import org.apache.cassandra.distributed.api.ICluster;
 import org.apache.cassandra.distributed.api.IInstance;
 import org.apache.cassandra.distributed.shared.JMXUtil;
 import org.apache.cassandra.sidecar.common.dns.DnsResolver;
+import org.apache.cassandra.sidecar.testing.MtlsTestHelper;
 import org.apache.cassandra.sidecar.testing.QualifiedName;
 import org.apache.cassandra.spark.KryoRegister;
 import org.apache.cassandra.spark.bulkwriter.BulkSparkConf;
@@ -52,8 +54,9 @@ import static org.assertj.core.api.Assertions.assertThat;
 public class SparkTestUtils
 {
     protected ICluster<? extends IInstance> cluster;
-    private DnsResolver dnsResolver;
-    private int sidecarPort;
+    protected DnsResolver dnsResolver;
+    protected int sidecarPort;
+    protected MtlsTestHelper mtlsTestHelper;
 
     /**
      * Runs any initialization code required for the tests
@@ -62,10 +65,12 @@ public class SparkTestUtils
      * @param dnsResolver the DNS resolver used to lookup replicas
      * @param sidecarPort the port where Sidecar is running
      */
-    public void initialize(ICluster<? extends IInstance> cluster, DnsResolver dnsResolver, int sidecarPort)
+    public void initialize(ICluster<? extends IInstance> cluster, DnsResolver dnsResolver, int sidecarPort,
+                           MtlsTestHelper mtlsTestHelper)
     {
         this.cluster = Objects.requireNonNull(cluster, "cluster is required");
         this.dnsResolver = Objects.requireNonNull(dnsResolver, "dnsResolver is required");
+        this.mtlsTestHelper = Objects.requireNonNull(mtlsTestHelper, "mtlsTestHelper is required");
         this.sidecarPort = sidecarPort;
     }
 
@@ -110,6 +115,7 @@ public class SparkTestUtils
                   .option("defaultParallelism", sc.defaultParallelism())
                   .option("numCores", numCores)
                   .option("sizing", "default")
+                  .options(mtlsTestHelper.mtlOptionMap())
                   .option("sidecar_port", sidecarPort);
     }
 
@@ -117,12 +123,13 @@ public class SparkTestUtils
      * Returns a {@link DataFrameWriter<Row>} with default options for performing a bulk write test, including
      * required parameters.
      *
-     * @param df        the source data frame
-     * @param tableName the qualified name of the table
+     * @param df                the source data frame
+     * @param tableName         the qualified name of the table
+     * @param additionalOptions additional options for the data frame
+     * @return a {@link DataFrameWriter<Row>} with default options for performing a bulk write test
      */
-    public DataFrameWriter<Row> defaultBulkWriterDataFrameWriter(Dataset<Row> df,
-                                                                 QualifiedName tableName)
+    public DataFrameWriter<Row> defaultBulkWriterDataFrameWriter(Dataset<Row> df, QualifiedName tableName,
+                                                                 Map<String, String> additionalOptions)
     {
         return df.write()
                  .format("org.apache.cassandra.spark.sparksql.CassandraDataSink")
@@ -133,6 +140,8 @@ public class SparkTestUtils
                  .option("bulk_writer_cl", "LOCAL_QUORUM")
                  .option("number_splits", "-1")
                  .option("sidecar_port", sidecarPort)
+                 .options(additionalOptions)
+                 .options(mtlsTestHelper.mtlOptionMap())
                  .mode("append");
     }
 
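
A closing note on defaultBulkWriterDataFrameWriter: the additional option maps are applied after the defaults, and DataFrameWriter keeps a single option map in which a later entry for the same key replaces an earlier one, which is what lets per-test options (and the mTLS options) override the pre-populated values. The following is a minimal sketch of that layering, assuming an arbitrary dataframe df and an illustrative override map; it is not code from the patch.

    import java.util.Collections;
    import java.util.Map;

    import org.apache.spark.sql.DataFrameWriter;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;

    final class OptionLayeringSketch
    {
        // Defaults are set first; overrides are applied last, so a later entry for the
        // same key (for example "bulk_writer_cl") wins over the pre-populated default.
        static DataFrameWriter<Row> layeredWriter(Dataset<Row> df, Map<String, String> overrides)
        {
            return df.write()
                     .format("org.apache.cassandra.spark.sparksql.CassandraDataSink")
                     .option("bulk_writer_cl", "LOCAL_QUORUM")
                     .option("number_splits", "-1")
                     .options(overrides)
                     .mode("append");
        }

        static DataFrameWriter<Row> defaultWriter(Dataset<Row> df)
        {
            return layeredWriter(df, Collections.emptyMap());
        }
    }
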

