This is an automated email from the ASF dual-hosted git repository. frankgh pushed a commit to branch trunk in repository https://gitbox.apache.org/repos/asf/cassandra-analytics.git
The following commit(s) were added to refs/heads/trunk by this push: new 98baab1 Make sure bridge exists 98baab1 is described below commit 98baab1b8f0d5d7eb93f8d13db3b0a7a985fb03a Author: Doug Rohrer <droh...@apple.com> AuthorDate: Tue Feb 27 22:03:04 2024 -0500 Make sure bridge exists --- .circleci/config.yml | 32 ++--- CHANGES.txt | 1 + .../spark/bulkwriter/BulkWriterContext.java | 3 + .../bulkwriter/CassandraBulkWriterContext.java | 26 +++- .../spark/bulkwriter/CassandraSchemaInfo.java | 14 +- .../spark/bulkwriter/CqlTableInfoProvider.java | 17 ++- .../cassandra/spark/bulkwriter/RecordWriter.java | 77 +++++++++-- .../cassandra/spark/bulkwriter/SSTableWriter.java | 8 +- .../spark/bulkwriter/SSTableWriterFactory.java | 4 + .../cassandra/spark/bulkwriter/SchemaInfo.java | 6 + .../spark/bulkwriter/SqlToCqlTypeConverter.java | 127 ++++++++++++------ .../spark/bulkwriter/token/TokenUtils.java | 3 +- .../cassandra/spark/data/LocalDataLayer.java | 17 +++ .../spark/bulkwriter/MockBulkWriterContext.java | 14 ++ .../bulkwriter/SqlToCqlTypeConverterTest.java | 2 + .../spark/bulkwriter/TableSchemaNormalizeTest.java | 28 +++- .../spark/bulkwriter/TableSchemaTest.java | 3 +- .../spark/bulkwriter/TableSchemaTestCommon.java | 31 +++++ .../testing/SharedClusterIntegrationTestBase.java | 53 +++++++- cassandra-analytics-integration-tests/build.gradle | 1 + .../cassandra/analytics/BulkWriteUdtTest.java | 145 +++++++++++++++++++++ .../cassandra/analytics/DataGenerationUtils.java | 77 ++++++++++- .../analytics/QuoteIdentifiersWriteTest.java | 71 +++++++++- .../SharedClusterSparkIntegrationTestBase.java | 88 +++++++++++++ .../apache/cassandra/bridge/CassandraBridge.java | 1 + .../cassandra/spark/data/BridgeUdtValue.java | 69 ++++++++++ .../bridge/CassandraBridgeImplementation.java | 4 +- .../bridge/SSTableWriterImplementation.java | 30 +++-- .../cassandra/spark/data/complex/CqlUdt.java | 4 + .../bridge/SSTableWriterImplementationTest.java | 2 + 30 files changed, 854 insertions(+), 104 
deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 9d5d7ef..b7603fc 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -147,7 +147,7 @@ jobs: - "*.jar" - "org/**/*" - cassandra-analytics-core-spark2-2_11-jdk8: + spark2-2_11-jdk8: docker: - image: cimg/openjdk:8.0 resource_class: large @@ -172,7 +172,7 @@ jobs: - store_test_results: path: build/test-reports - cassandra-analytics-core-int-spark2-2_11-jdk8: + int-spark2-2_11-jdk8: parallelism: 8 docker: - image: cimg/openjdk:8.0 @@ -198,7 +198,7 @@ jobs: - store_test_results: path: build/test-reports - cassandra-analytics-core-spark2-2_12-jdk8: + spark2-2_12-jdk8: docker: - image: cimg/openjdk:8.0 resource_class: large @@ -223,7 +223,7 @@ jobs: - store_test_results: path: build/test-reports - cassandra-analytics-core-int-spark2-2_12-jdk8: + int-spark2-2_12-jdk8: parallelism: 8 docker: - image: cimg/openjdk:8.0 @@ -249,7 +249,7 @@ jobs: - store_test_results: path: build/test-reports - cassandra-analytics-core-spark3-2_12-jdk11: + spark3-2_12-jdk11: docker: - image: cimg/openjdk:11.0 resource_class: large @@ -275,7 +275,7 @@ jobs: - store_test_results: path: build/test-reports - cassandra-analytics-core-int-spark3-2_12-jdk11: + int-spark3-2_12-jdk11: parallelism: 8 docker: - image: cimg/openjdk:11.0 @@ -302,7 +302,7 @@ jobs: - store_test_results: path: build/test-reports - cassandra-analytics-core-spark3-2_13-jdk11: + spark3-2_13-jdk11: docker: - image: cimg/openjdk:11.0 resource_class: large @@ -328,7 +328,7 @@ jobs: - store_test_results: path: build/test-reports - cassandra-analytics-core-int-spark3-2_13-jdk11: + int-spark3-2_13-jdk11: parallelism: 8 docker: - image: cimg/openjdk:11.0 @@ -361,27 +361,27 @@ workflows: jobs: - build-dependencies-jdk8 - build-dependencies-jdk11 - - cassandra-analytics-core-spark2-2_11-jdk8: + - spark2-2_11-jdk8: requires: - build-dependencies-jdk8 - - cassandra-analytics-core-spark2-2_12-jdk8: + - spark2-2_12-jdk8: requires: - build-dependencies-jdk8 
- - cassandra-analytics-core-spark3-2_12-jdk11: + - spark3-2_12-jdk11: requires: - build-dependencies-jdk11 - - cassandra-analytics-core-spark3-2_13-jdk11: + - spark3-2_13-jdk11: requires: - build-dependencies-jdk11 - - cassandra-analytics-core-int-spark2-2_11-jdk8: + - int-spark2-2_11-jdk8: requires: - build-dependencies-jdk8 - - cassandra-analytics-core-int-spark2-2_12-jdk8: + - int-spark2-2_12-jdk8: requires: - build-dependencies-jdk8 - - cassandra-analytics-core-int-spark3-2_12-jdk11: + - int-spark3-2_12-jdk11: requires: - build-dependencies-jdk11 - - cassandra-analytics-core-int-spark3-2_13-jdk11: + - int-spark3-2_13-jdk11: requires: - build-dependencies-jdk11 diff --git a/CHANGES.txt b/CHANGES.txt index 914d933..741584c 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,4 +1,5 @@ 1.0.0 + * Support UDTs in the Bulk Writer (CASSANDRA-19340) * Fix bulk reads of multiple tables that potentially have the same data file name (CASSANDRA-19507) * Fix XXHash32Digest calculated digest value (CASSANDRA-19500) * Report additional bulk analytics job stats for instrumentation (CASSANDRA-19418) diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkWriterContext.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkWriterContext.java index 10928d4..945f8a2 100644 --- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkWriterContext.java +++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkWriterContext.java @@ -22,6 +22,7 @@ package org.apache.cassandra.spark.bulkwriter; import java.io.Serializable; import org.apache.cassandra.spark.common.stats.JobStatsPublisher; +import org.apache.cassandra.bridge.CassandraBridge; public interface BulkWriterContext extends Serializable { @@ -35,6 +36,8 @@ public interface BulkWriterContext extends Serializable DataTransferApi transfer(); + CassandraBridge bridge(); + // NOTE: This interface intentionally 
does *not* implement AutoClosable as Spark can close Broadcast variables // that implement AutoClosable while they are still in use, causing the underlying object to become unusable void shutdown(); diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraBulkWriterContext.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraBulkWriterContext.java index 84d100c..0999604 100644 --- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraBulkWriterContext.java +++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraBulkWriterContext.java @@ -54,6 +54,8 @@ public class CassandraBulkWriterContext implements BulkWriterContext, KryoSerial @NotNull private final BulkSparkConf conf; private final JobInfo jobInfo; + private final String lowestCassandraVersion; + private transient CassandraBridge bridge; private transient DataTransferApi dataTransferApi; private final CassandraClusterInfo clusterInfo; private final SchemaInfo schemaInfo; @@ -68,9 +70,8 @@ public class CassandraBulkWriterContext implements BulkWriterContext, KryoSerial this.conf = conf; this.clusterInfo = clusterInfo; this.jobStatsPublisher = new LogStatsPublisher(); - String lowestCassandraVersion = clusterInfo.getLowestCassandraVersion(); - CassandraBridge bridge = CassandraBridgeFactory.get(lowestCassandraVersion); - + lowestCassandraVersion = clusterInfo.getLowestCassandraVersion(); + this.bridge = CassandraBridgeFactory.get(lowestCassandraVersion); TokenRangeMapping<RingInstance> tokenRangeMapping = clusterInfo.getTokenRangeMapping(true); jobInfo = new CassandraJobInfo(conf, new TokenPartitioner(tokenRangeMapping, @@ -92,11 +93,23 @@ public class CassandraBulkWriterContext implements BulkWriterContext, KryoSerial Set<String> udts = CqlUtils.extractUdts(keyspaceSchema, keyspace); ReplicationFactor replicationFactor = 
CqlUtils.extractReplicationFactor(keyspaceSchema, keyspace); int indexCount = CqlUtils.extractIndexCount(keyspaceSchema, keyspace, table); - CqlTable cqlTable = bridge.buildSchema(createTableSchema, keyspace, replicationFactor, partitioner, udts, null, indexCount); + CqlTable cqlTable = bridge().buildSchema(createTableSchema, keyspace, replicationFactor, partitioner, udts, null, indexCount); TableInfoProvider tableInfoProvider = new CqlTableInfoProvider(createTableSchema, cqlTable); TableSchema tableSchema = initializeTableSchema(conf, dfSchema, tableInfoProvider, lowestCassandraVersion); - schemaInfo = new CassandraSchemaInfo(tableSchema); + schemaInfo = new CassandraSchemaInfo(tableSchema, udts, cqlTable); + } + + @Override + public CassandraBridge bridge() + { + CassandraBridge currentBridge = this.bridge; + if (currentBridge != null) + { + return currentBridge; + } + this.bridge = CassandraBridgeFactory.get(lowestCassandraVersion); + return bridge; } public static BulkWriterContext fromOptions(@NotNull SparkContext sparkContext, @@ -204,9 +217,8 @@ public class CassandraBulkWriterContext implements BulkWriterContext, KryoSerial { if (dataTransferApi == null) { - CassandraBridge bridge = CassandraBridgeFactory.get(clusterInfo.getLowestCassandraVersion()); dataTransferApi = new SidecarDataTransferApi(clusterInfo.getCassandraContext(), - bridge, + bridge(), jobInfo, conf); } diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraSchemaInfo.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraSchemaInfo.java index d55b49b..7c63320 100644 --- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraSchemaInfo.java +++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraSchemaInfo.java @@ -19,14 +19,20 @@ package org.apache.cassandra.spark.bulkwriter; +import java.util.Set; + +import 
org.apache.cassandra.spark.data.CqlTable; + public class CassandraSchemaInfo implements SchemaInfo { private static final long serialVersionUID = -2327383232935001862L; private final TableSchema tableSchema; + private final Set<String> userDefinedTypeStatements; - public CassandraSchemaInfo(TableSchema tableSchema) + public CassandraSchemaInfo(TableSchema tableSchema, Set<String> userDefinedTypeStatements, CqlTable cqlTable) { this.tableSchema = tableSchema; + this.userDefinedTypeStatements = userDefinedTypeStatements; } @Override @@ -34,4 +40,10 @@ public class CassandraSchemaInfo implements SchemaInfo { return tableSchema; } + + @Override + public Set<String> getUserDefinedTypeStatements() + { + return userDefinedTypeStatements; + } } diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CqlTableInfoProvider.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CqlTableInfoProvider.java index e512d27..29977cf 100644 --- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CqlTableInfoProvider.java +++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CqlTableInfoProvider.java @@ -98,9 +98,20 @@ public class CqlTableInfoProvider implements TableInfoProvider @Override public List<ColumnType<?>> getPartitionKeyTypes() { - return cqlTable.partitionKeys().stream() - .map(cqlField -> DATA_TYPES.get(cqlField.type().cqlName().toLowerCase())) - .collect(Collectors.toList()); + List<ColumnType<?>> types = cqlTable.partitionKeys().stream() + .map(cqlField -> { + String typeName = cqlField.type().cqlName().toLowerCase(); + ColumnType<?> type = DATA_TYPES.get(typeName); + if (type == null) + { + throw new RuntimeException( + "Could not find ColumnType for type name" + typeName); + } + return type; + }) + .collect(Collectors.toList()); + return types; + } @Override diff --git 
a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java index 232c461..b6fea5a 100644 --- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java +++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java @@ -37,6 +37,7 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -44,6 +45,7 @@ import java.util.stream.Stream; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Range; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -51,6 +53,10 @@ import org.slf4j.LoggerFactory; import o.a.c.sidecar.client.shaded.common.data.TimeSkewResponse; import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler; import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping; +import org.apache.cassandra.spark.data.BridgeUdtValue; +import org.apache.cassandra.spark.data.CqlField; +import org.apache.cassandra.spark.data.CqlTable; +import org.apache.cassandra.spark.data.ReplicationFactor; import org.apache.cassandra.spark.utils.DigestAlgorithm; import org.apache.spark.InterruptibleIterator; import org.apache.spark.TaskContext; @@ -61,6 +67,8 @@ import static org.apache.cassandra.spark.utils.ScalaConversionUtils.asScalaItera @SuppressWarnings({ "ConstantConditions" }) public class RecordWriter implements Serializable { + public static final ReplicationFactor IGNORED_REPLICATION_FACTOR = new ReplicationFactor(ReplicationFactor.ReplicationStrategy.SimpleStrategy, + ImmutableMap.of("replication_factor", 
1)); private static final Logger LOGGER = LoggerFactory.getLogger(RecordWriter.class); private static final long serialVersionUID = 3746578054834640428L; private final BulkWriterContext writerContext; @@ -72,8 +80,10 @@ public class RecordWriter implements Serializable private final ReplicaAwareFailureHandler<RingInstance> failureHandler; private final Supplier<TaskContext> taskContextSupplier; + private final ConcurrentHashMap<String, CqlField.CqlUdt> udtCache = new ConcurrentHashMap<>(); private SSTableWriter sstableWriter = null; private int outputSequence = 0; // sub-folder for possible subrange splits + private transient volatile CqlTable cqlTable; public RecordWriter(BulkWriterContext writerContext, String[] columnNames) { @@ -107,6 +117,21 @@ public class RecordWriter implements Serializable return String.format("%d-%s", taskContext.partitionId(), UUID.randomUUID()); } + private CqlTable cqlTable() + { + if (cqlTable == null) + { + cqlTable = writerContext.bridge() + .buildSchema(writerContext.schema().getTableSchema().createStatement, + writerContext.job().keyspace(), + IGNORED_REPLICATION_FACTOR, + writerContext.cluster().getPartitioner(), + writerContext.schema().getUserDefinedTypeStatements()); + } + + return cqlTable; + } + public WriteResult write(Iterator<Tuple2<DecoratedKey, Object[]>> sourceIterator) { TaskContext taskContext = taskContextSupplier.get(); @@ -209,6 +234,14 @@ public class RecordWriter implements Serializable } } + public static <T> Set<T> symmetricDifference(Set<T> set1, Set<T> set2) + { + return Stream.concat( + set1.stream().filter(element -> !set2.contains(element)), + set2.stream().filter(element -> !set1.contains(element))) + .collect(Collectors.toSet()); + } + private Map<Range<BigInteger>, List<RingInstance>> taskTokenRangeMapping(TokenRangeMapping<RingInstance> tokenRange, Range<BigInteger> taskTokenRange) { @@ -308,14 +341,6 @@ public class RecordWriter implements Serializable } } - public static <T> Set<T> 
symmetricDifference(Set<T> set1, Set<T> set2) - { - return Stream.concat( - set1.stream().filter(element -> !set2.contains(element)), - set2.stream().filter(element -> !set1.contains(element))) - .collect(Collectors.toSet()); - } - private void validateAcceptableTimeSkewOrThrow(List<RingInstance> replicas) { if (replicas.isEmpty()) @@ -370,16 +395,48 @@ public class RecordWriter implements Serializable } } - private static Map<String, Object> getBindValuesForColumns(Map<String, Object> map, String[] columnNames, Object[] values) + private Map<String, Object> getBindValuesForColumns(Map<String, Object> map, String[] columnNames, Object[] values) { assert values.length == columnNames.length : "Number of values does not match the number of columns " + values.length + ", " + columnNames.length; for (int i = 0; i < columnNames.length; i++) { - map.put(columnNames[i], values[i]); + map.put(columnNames[i], maybeConvertUdt(values[i])); } return map; } + private Object maybeConvertUdt(Object value) + { + if (value instanceof BridgeUdtValue) + { + BridgeUdtValue udtValue = (BridgeUdtValue) value; + // Depth-first replacement of BridgeUdtValue instances to their appropriate Cql types + for (Map.Entry<String, Object> entry : udtValue.udtMap.entrySet()) + { + if (entry.getValue() instanceof BridgeUdtValue) + { + udtValue.udtMap.put(entry.getKey(), maybeConvertUdt(entry.getValue())); + } + } + return getUdt(udtValue.name).convertForCqlWriter(udtValue.udtMap, writerContext.bridge().getVersion()); + } + return value; + } + + private synchronized CqlField.CqlType getUdt(String udtName) + { + return udtCache.computeIfAbsent(udtName, name -> { + for (CqlField.CqlUdt udt1 : cqlTable().udts()) + { + if (udt1.cqlName().equals(name)) + { + return udt1; + } + } + throw new IllegalArgumentException("Could not find udt with name " + name); + }); + } + /** * Close the {@link RecordWriter#sstableWriter} if present. Schedule a stream session with the produced sstables. 
* And finally, nullify {@link RecordWriter#sstableWriter} diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SSTableWriter.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SSTableWriter.java index 8c1a35f..addbc11 100644 --- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SSTableWriter.java +++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SSTableWriter.java @@ -27,6 +27,7 @@ import java.nio.file.Path; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Set; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Range; @@ -78,13 +79,15 @@ public class SSTableWriter String packageVersion = getPackageVersion(lowestCassandraVersion); LOGGER.info("Running with version " + packageVersion); - TableSchema tableSchema = writerContext.schema().getTableSchema(); + SchemaInfo schema = writerContext.schema(); + TableSchema tableSchema = schema.getTableSchema(); this.cqlSSTableWriter = SSTableWriterFactory.getSSTableWriter( CassandraVersionFeatures.cassandraVersionFeaturesFromCassandraVersion(packageVersion), this.outDir.toString(), writerContext.cluster().getPartitioner().toString(), tableSchema.createStatement, tableSchema.modificationStatement, + schema.getUserDefinedTypeStatements(), writerContext.job().sstableDataSizeInMiB()); } @@ -137,8 +140,9 @@ public class SSTableWriter CassandraVersion version = CassandraBridgeFactory.getCassandraVersion(writerContext.cluster().getLowestCassandraVersion()); String keyspace = writerContext.job().keyspace(); String schema = writerContext.schema().getTableSchema().createStatement; + Set<String> udtStatements = writerContext.schema().getUserDefinedTypeStatements(); String directory = getOutDir().toString(); - DataLayer layer = new LocalDataLayer(version, keyspace, schema, directory); + DataLayer layer = new 
LocalDataLayer(version, keyspace, schema, udtStatements, directory); try (StreamScanner<Rid> scanner = layer.openCompactionScanner(partitionId, Collections.emptyList(), null)) { while (scanner.hasNext()) diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SSTableWriterFactory.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SSTableWriterFactory.java index 55cace3..77b8f5f 100644 --- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SSTableWriterFactory.java +++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SSTableWriterFactory.java @@ -19,6 +19,8 @@ package org.apache.cassandra.spark.bulkwriter; +import java.util.Set; + import org.apache.cassandra.bridge.CassandraBridge; import org.apache.cassandra.bridge.CassandraBridgeFactory; import org.apache.cassandra.bridge.CassandraVersionFeatures; @@ -36,6 +38,7 @@ public final class SSTableWriterFactory String partitioner, String createStatement, String insertStatement, + Set<String> userDefinedTypeStatements, int bufferSizeMB) { CassandraBridge cassandraBridge = CassandraBridgeFactory.get(serverVersion); @@ -43,6 +46,7 @@ public final class SSTableWriterFactory partitioner, createStatement, insertStatement, + userDefinedTypeStatements, bufferSizeMB); } } diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SchemaInfo.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SchemaInfo.java index ca95618..0257d29 100644 --- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SchemaInfo.java +++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SchemaInfo.java @@ -20,8 +20,14 @@ package org.apache.cassandra.spark.bulkwriter; import java.io.Serializable; +import java.util.Set; + +import org.jetbrains.annotations.NotNull; public interface SchemaInfo extends Serializable { 
TableSchema getTableSchema(); + + @NotNull + Set<String> getUserDefinedTypeStatements(); } diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SqlToCqlTypeConverter.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SqlToCqlTypeConverter.java index a19ca21..a80c8c6 100644 --- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SqlToCqlTypeConverter.java +++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SqlToCqlTypeConverter.java @@ -41,8 +41,10 @@ import com.google.common.net.InetAddresses; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.cassandra.spark.data.BridgeUdtValue; import org.apache.cassandra.spark.data.CqlField; import org.apache.cassandra.spark.utils.UUIDs; +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; import scala.Tuple2; import static org.apache.cassandra.spark.utils.ScalaConversionUtils.asJavaIterable; @@ -50,8 +52,6 @@ import static org.apache.cassandra.spark.utils.ScalaConversionUtils.asJavaIterab @SuppressWarnings("unchecked") public final class SqlToCqlTypeConverter implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(SqlToCqlTypeConverter.class); - public static final String ASCII = "ascii"; public static final String BIGINT = "bigint"; public static final String BLOB = "blob"; @@ -79,7 +79,7 @@ public final class SqlToCqlTypeConverter implements Serializable public static final String UDT = "udt"; public static final String VARCHAR = "varchar"; public static final String VARINT = "varint"; - + private static final Logger LOGGER = LoggerFactory.getLogger(SqlToCqlTypeConverter.class); private static final NoOp<Object> NO_OP_CONVERTER = new NoOp<>(); private static final LongConverter LONG_CONVERTER = new LongConverter(); private static final BytesConverter BYTES_CONVERTER = new BytesConverter(); @@ -164,11 +164,14 @@ public final 
class SqlToCqlTypeConverter implements Serializable return new MapConverter<>((CqlField.CqlMap) cqlType); case SET: return new SetConverter<>((CqlField.CqlCollection) cqlType); - case UDT: - return NO_OP_CONVERTER; case TUPLE: return NO_OP_CONVERTER; default: + if (cqlType.internalType() == CqlField.CqlType.InternalType.Udt) + { + assert cqlType instanceof CqlField.CqlUdt; + return new UdtConverter((CqlField.CqlUdt) cqlType); + } LOGGER.warn("Unable to match type={}. Defaulting to NoOp Converter", cqlName); return NO_OP_CONVERTER; } @@ -212,12 +215,12 @@ public final class SqlToCqlTypeConverter implements Serializable abstract static class Converter<T> implements Serializable { - public abstract T convertInternal(Object object) throws RuntimeException; - public T convert(Object object) { return convertInternal(object); } + + abstract T convertInternal(Object object); } private abstract static class NullableConverter<T> extends Converter<T> @@ -441,7 +444,7 @@ public final class SqlToCqlTypeConverter implements Serializable * @throws RuntimeException when the object cannot be converted to timestamp */ @Override - public Long convertInternal(Object object) throws RuntimeException + public Long convertInternal(Object object) { if (object instanceof Date) { @@ -479,7 +482,7 @@ public final class SqlToCqlTypeConverter implements Serializable * @throws RuntimeException when the object cannot be converted to timestamp */ @Override - public Date convertInternal(Object object) throws RuntimeException + public Date convertInternal(Object object) { if (object instanceof Date) { @@ -510,22 +513,6 @@ public final class SqlToCqlTypeConverter implements Serializable return "Date"; } - protected int fromDate(Date value) - { - long millisSinceEpoch = value.getTime(); - return fromMillisSinceEpoch(millisSinceEpoch); - } - - protected int fromMillisSinceEpoch(long millisSinceEpoch) - { - // NOTE: This code is lifted from 
org.apache.cassandra.serializers.SimpleDateSerializer#timeInMillisToDay. - // Reproduced here due to the difficulties of referencing classes from specific versions of Cassandra - // in the SBW. - int result = (int) TimeUnit.MILLISECONDS.toDays(millisSinceEpoch); - result -= Integer.MIN_VALUE; - return result; - } - @Override public Integer convertInternal(Object object) { @@ -542,6 +529,22 @@ public final class SqlToCqlTypeConverter implements Serializable throw new RuntimeException("Unsupported conversion for DATE from " + object.getClass().getTypeName()); } } + + protected int fromDate(Date value) + { + long millisSinceEpoch = value.getTime(); + return fromMillisSinceEpoch(millisSinceEpoch); + } + + protected int fromMillisSinceEpoch(long millisSinceEpoch) + { + // NOTE: This code is lifted from org.apache.cassandra.serializers.SimpleDateSerializer#timeInMillisToDay. + // Reproduced here due to the difficulties of referencing classes from specific versions of Cassandra + // in the SBW. 
+ int result = (int) TimeUnit.MILLISECONDS.toDays(millisSinceEpoch); + result -= Integer.MIN_VALUE; + return result; + } } static class TimeConverter extends NullableConverter<Long> @@ -674,6 +677,12 @@ public final class SqlToCqlTypeConverter implements Serializable } } + @Override + public String toString() + { + return "List"; + } + private List<E> makeList(Iterable<?> iterable) { List<E> list = new ArrayList<>(); @@ -683,12 +692,6 @@ public final class SqlToCqlTypeConverter implements Serializable } return list; } - - @Override - public String toString() - { - return "List"; - } } static class SetConverter<E> extends NullableConverter<Set<E>> @@ -717,6 +720,12 @@ public final class SqlToCqlTypeConverter implements Serializable } } + @Override + public String toString() + { + return "Set<" + innerConverter.toString() + ">"; + } + private Set<E> makeSet(Iterable<?> iterable) { Set<E> set = new HashSet<>(); @@ -726,12 +735,6 @@ public final class SqlToCqlTypeConverter implements Serializable } return set; } - - @Override - public String toString() - { - return "Set<" + innerConverter.toString() + ">"; - } } static class MapConverter<K, V> extends NullableConverter<Map<K, V>> @@ -763,6 +766,12 @@ public final class SqlToCqlTypeConverter implements Serializable throw new RuntimeException("Unsupported conversion for MAP from " + object.getClass().getTypeName()); } + @Override + public String toString() + { + return "Map<" + keyConverter.toString() + ", " + valConverter.toString() + '>'; + } + private Map<K, V> makeMap(Iterable<?> iterable) { Object key; @@ -788,11 +797,53 @@ public final class SqlToCqlTypeConverter implements Serializable } return map; } + } + + public static class UdtConverter extends NullableConverter<BridgeUdtValue> + { + private final String name; + private final HashMap<String, Converter<?>> converters; + + UdtConverter(CqlField.CqlUdt udt) + { + this.name = udt.cqlName(); + this.converters = new HashMap<>(); + for (CqlField f : udt.fields()) + 
{ + converters.put(f.name(), getConverter(f.type())); + } + } + + @Override + public BridgeUdtValue convertInternal(Object object) + { + if (object instanceof GenericRowWithSchema) + { + Map<String, Object> udtMap = makeUdtMap((GenericRowWithSchema) object); + return new BridgeUdtValue(name, udtMap); + } + throw new RuntimeException("Unsupported conversion for UDT from " + object.getClass().getTypeName()); + } @Override public String toString() { - return "Map"; + return String.format("UDT[%s]", name); + } + + // Unfortunately, we don't have easy access to the bridge here. + // Rather than trying to create an actual UDTValue here, we will push + // that responsibility down to the SSTableWriter Implementation + private Map<String, Object> makeUdtMap(GenericRowWithSchema row) + { + Map<String, Object> result = new HashMap<>(); + for (String fieldName : row.schema().fieldNames()) + { + Converter<?> converter = converters.get(fieldName); + Object val = row.get(row.fieldIndex(fieldName)); + result.put(fieldName, converter.convert(val)); + } + return result; } } } diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/token/TokenUtils.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/token/TokenUtils.java index b0169ae..d9b0e67 100644 --- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/token/TokenUtils.java +++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/token/TokenUtils.java @@ -38,7 +38,7 @@ import org.apache.cassandra.spark.data.partitioner.Partitioner; * This reduces number of SSTables that get created in Cassandra by the bulk writing job. * Fewer SSTables will result in lower read latencies and lower compaction overhead. 
*/ -@SuppressWarnings("WeakerAccess") +@SuppressWarnings({"WeakerAccess", "rawtypes", "unchecked"}) public class TokenUtils implements Serializable { private final String[] partitionKeyColumns; @@ -54,7 +54,6 @@ public class TokenUtils implements Serializable this.isMurmur3Partitioner = isMurmur3Partitioner; } - // noinspection unchecked private ByteBuffer getByteBuffer(Object columnValue, int partitionKeyColumnIdx) { ColumnType columnType = partitionKeyColumnTypes[partitionKeyColumnIdx]; diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/LocalDataLayer.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/LocalDataLayer.java index 214d5cc..1f79d7e 100644 --- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/LocalDataLayer.java +++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/LocalDataLayer.java @@ -195,6 +195,23 @@ public class LocalDataLayer extends DataLayer implements Serializable paths); } + public LocalDataLayer(@NotNull CassandraVersion version, + @NotNull String keyspace, + @NotNull String createStatement, + @NotNull Set<String> udtStatements, + String... 
paths) + { + this(version, + Partitioner.Murmur3Partitioner, + keyspace, + createStatement, + udtStatements, + Collections.emptyList(), + false, + null, + paths); + } + // CHECKSTYLE IGNORE: Constructor with many parameters public LocalDataLayer(@NotNull CassandraVersion version, @NotNull Partitioner partitioner, diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockBulkWriterContext.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockBulkWriterContext.java index 91e7174..cf6a6f5 100644 --- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockBulkWriterContext.java +++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockBulkWriterContext.java @@ -40,7 +40,9 @@ import com.google.common.collect.ImmutableMap; import org.apache.commons.lang3.tuple.Pair; import o.a.c.sidecar.client.shaded.common.data.TimeSkewResponse; +import org.apache.cassandra.bridge.CassandraBridge; import org.apache.cassandra.bridge.CassandraBridgeFactory; +import org.apache.cassandra.bridge.CassandraVersion; import org.apache.cassandra.spark.bulkwriter.token.ConsistencyLevel; import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping; import org.apache.cassandra.spark.common.Digest; @@ -77,6 +79,7 @@ public class MockBulkWriterContext implements BulkWriterContext, ClusterInfo, Jo private ConsistencyLevel.CL consistencyLevel; private int sstableDataSizeInMB = 128; private int sstableWriteBatchSize = 2; + private CassandraBridge bridge = CassandraBridgeFactory.get(CassandraVersion.FOURZERO); @Override public void publish(Map<String, String> stats) @@ -295,6 +298,12 @@ public class MockBulkWriterContext implements BulkWriterContext, ClusterInfo, Jo return schema; } + @Override + public Set<String> getUserDefinedTypeStatements() + { + return Collections.emptySet(); + } + @Override public Partitioner getPartitioner() { @@ -446,6 +455,11 @@ public class 
MockBulkWriterContext implements BulkWriterContext, ClusterInfo, Jo return this; } + public CassandraBridge bridge() + { + return bridge; + } + @Override public boolean quoteIdentifiers() { diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SqlToCqlTypeConverterTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SqlToCqlTypeConverterTest.java index ccfd7e6..c72bddd 100644 --- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SqlToCqlTypeConverterTest.java +++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SqlToCqlTypeConverterTest.java @@ -52,6 +52,7 @@ import static org.apache.cassandra.spark.bulkwriter.TableSchemaTestCommon.mockCq import static org.apache.cassandra.spark.bulkwriter.TableSchemaTestCommon.mockListCqlType; import static org.apache.cassandra.spark.bulkwriter.TableSchemaTestCommon.mockMapCqlType; import static org.apache.cassandra.spark.bulkwriter.TableSchemaTestCommon.mockSetCqlType; +import static org.apache.cassandra.spark.bulkwriter.TableSchemaTestCommon.mockUdtCqlType; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.MatcherAssert.assertThat; @@ -73,6 +74,7 @@ public final class SqlToCqlTypeConverterTest na(mockCqlType(DATE), SqlToCqlTypeConverter.DateConverter.class), na(mockMapCqlType(INT, INT), SqlToCqlTypeConverter.MapConverter.class), na(mockSetCqlType(INET), SqlToCqlTypeConverter.SetConverter.class), + na(mockUdtCqlType("udtType", "f1", TEXT, "f2", INT, "f3", TIMEUUID), SqlToCqlTypeConverter.UdtConverter.class), // Special Cassandra 1.2 Timestamp type should use TimestampConverter na(mockCqlCustom("org.apache.cassandra.db.marshal.DateType"), SqlToCqlTypeConverter.TimestampConverter.class)); } diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/TableSchemaNormalizeTest.java 
b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/TableSchemaNormalizeTest.java index 0cc628e..fe106bb 100644 --- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/TableSchemaNormalizeTest.java +++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/TableSchemaNormalizeTest.java @@ -34,6 +34,7 @@ import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; +import com.google.common.collect.ImmutableMap; import com.google.common.net.InetAddresses; import org.junit.jupiter.api.Test; @@ -42,8 +43,13 @@ import org.apache.cassandra.spark.common.schema.ColumnTypes; import org.apache.cassandra.spark.common.schema.ListType; import org.apache.cassandra.spark.common.schema.MapType; import org.apache.cassandra.spark.common.schema.SetType; +import org.apache.cassandra.spark.data.BridgeUdtValue; import org.apache.cassandra.spark.data.CqlField; +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import static java.util.AbstractMap.SimpleEntry; import static org.apache.cassandra.spark.bulkwriter.SqlToCqlTypeConverter.ASCII; @@ -73,6 +79,7 @@ import static org.apache.cassandra.spark.bulkwriter.TableSchemaTestCommon.mockCq import static org.apache.cassandra.spark.bulkwriter.TableSchemaTestCommon.mockListCqlType; import static org.apache.cassandra.spark.bulkwriter.TableSchemaTestCommon.mockMapCqlType; import static org.apache.cassandra.spark.bulkwriter.TableSchemaTestCommon.mockSetCqlType; +import static org.apache.cassandra.spark.bulkwriter.TableSchemaTestCommon.mockUdtCqlType; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.core.Is.is; import static org.hamcrest.core.IsEqual.equalTo; @@ -296,6 +303,24 @@ public class 
TableSchemaNormalizeTest DataTypes.createArrayType(DataTypes.createArrayType(DataTypes.BinaryType)))); } + @Test + public void testUdtNormalization() + { + StructType structType = new StructType() + .add(new StructField("f1", DataTypes.IntegerType, false, Metadata.empty())) + .add(new StructField("f2", DataTypes.StringType, false, Metadata.empty())); + + GenericRowWithSchema source = new GenericRowWithSchema(new Object[]{1, "course"}, structType); + // NOTE: UDT Types carry their type name around, so the use of `udt_field` consistently here is a bit + // "wrong" for the real-world, but is tested by integration tests elsewhere and is correct for the way + // the asserts in this test work. + BridgeUdtValue udtValue = new BridgeUdtValue("udt_field", ImmutableMap.of("f1", 1, "f2", "course")); + + CqlField.CqlUdt cqlType = mockUdtCqlType("udt_field", "f1", INT, "f2", TEXT); + assertNormalized("udt_field", cqlType, new MapType<>(ColumnTypes.STRING, new ListType<>(ColumnTypes.BYTES)), + source, udtValue, structType); + } + private void assertNormalized(String field, CqlField.CqlType cqlType, ColumnType<?> columnType, @@ -309,6 +334,7 @@ public class TableSchemaNormalizeTest TableSchema schema = buildSchema(fieldNames, sparkTypes, new CqlField.CqlType[]{cqlType}, fieldNames, cqlTypes, fieldNames); Object[] source = new Object[]{sourceVal}; Object[] expected = new Object[]{expectedVal}; - assertThat(schema.normalize(source), is(equalTo(expected))); + Object[] normalized = schema.normalize(source); + assertThat(normalized, is(equalTo(expected))); } } diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/TableSchemaTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/TableSchemaTest.java index 5cb184f..83564f5 100644 --- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/TableSchemaTest.java +++ 
b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/TableSchemaTest.java @@ -159,7 +159,8 @@ public class TableSchemaTest TableSchema schema = getValidSchemaBuilder() .build(); - assertThat(schema.normalize(new Object[]{1, 1L, "foo", 2}), is(equalTo(new Object[]{1, -2147483648, "foo", 2}))); + assertThat(schema.normalize(new Object[]{1, 1L, "foo", 2}), + is(equalTo(new Object[]{1, -2147483648, "foo", 2}))); } @Test diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/TableSchemaTestCommon.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/TableSchemaTestCommon.java index ef937ce..8dbc5fc 100644 --- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/TableSchemaTestCommon.java +++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/TableSchemaTestCommon.java @@ -19,7 +19,9 @@ package org.apache.cassandra.spark.bulkwriter; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -43,6 +45,7 @@ import static org.apache.cassandra.spark.bulkwriter.SqlToCqlTypeConverter.CUSTOM import static org.apache.cassandra.spark.bulkwriter.SqlToCqlTypeConverter.LIST; import static org.apache.cassandra.spark.bulkwriter.SqlToCqlTypeConverter.MAP; import static org.apache.cassandra.spark.bulkwriter.SqlToCqlTypeConverter.SET; +import static org.apache.cassandra.spark.bulkwriter.SqlToCqlTypeConverter.UDT; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -124,6 +127,34 @@ public final class TableSchemaTestCommon return mock; } + @NotNull + public static CqlField.CqlUdt mockUdtCqlType(String name, String... 
namesAndTypes) + { + assert namesAndTypes.length > 0 && (namesAndTypes.length % 2) == 0; + HashMap<String, CqlField> udtDef = new HashMap<>(); + CqlField.CqlUdt udtMock = mock(CqlField.CqlUdt.class); + when(udtMock.cqlName()).thenReturn(name); + when(udtMock.internalType()).thenReturn(CqlField.CqlType.InternalType.Udt); + when(udtMock.name()).thenReturn(UDT); + List<CqlField> fields = new ArrayList<>(); + for (int i = 0; i < namesAndTypes.length; i += 2) + { + String field = namesAndTypes[i]; + String type = namesAndTypes[i + 1]; + CqlField mock = mock(CqlField.class); + when(mock.name()).thenReturn(field); + when(mock.cqlTypeName()).thenReturn(type); + CqlField.CqlType fieldType = mockCqlType(type); + when(mock.type()).thenReturn(fieldType); + udtDef.put(field, mock); + when(udtMock.field(i / 2)).thenReturn(mock); + when(udtMock.field(field)).thenReturn(mock); + fields.add(mock); + } + when(udtMock.fields()).thenReturn(fields); + return udtMock; + } + public static TableSchema buildSchema(String[] fieldNames, org.apache.spark.sql.types.DataType[] sparkTypes, CqlField.CqlType[] driverTypes, diff --git a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java index 028b436..267c175 100644 --- a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java +++ b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java @@ -54,6 +54,8 @@ import io.vertx.junit5.VertxExtension; import io.vertx.junit5.VertxTestContext; import org.apache.cassandra.distributed.UpgradeableCluster; import org.apache.cassandra.distributed.api.ConsistencyLevel; +import org.apache.cassandra.distributed.api.Feature; +import 
org.apache.cassandra.distributed.api.ICluster; import org.apache.cassandra.distributed.api.IInstance; import org.apache.cassandra.distributed.api.IInstanceConfig; import org.apache.cassandra.distributed.impl.AbstractCluster; @@ -81,6 +83,10 @@ import org.apache.cassandra.testing.TestUtils; import org.apache.cassandra.testing.TestVersion; import org.apache.cassandra.testing.TestVersionSupplier; import org.apache.cassandra.utils.Throwables; +import shaded.com.datastax.driver.core.Cluster; +import shaded.com.datastax.driver.core.ResultSet; +import shaded.com.datastax.driver.core.Session; +import shaded.com.datastax.driver.core.SimpleStatement; import static org.apache.cassandra.sidecar.testing.CassandraSidecarTestContext.tryGetIntConfig; import static org.assertj.core.api.Assertions.assertThat; @@ -130,6 +136,7 @@ public abstract class SharedClusterIntegrationTestBase protected AbstractCluster<? extends IInstance> cluster; protected Server server; protected Injector injector; + protected TestVersion testVersion; static { @@ -140,10 +147,11 @@ public abstract class SharedClusterIntegrationTestBase @BeforeAll protected void setup() throws InterruptedException, IOException { - Optional<TestVersion> testVersion = TestVersionSupplier.testVersions().findFirst(); - assertThat(testVersion).isPresent(); + Optional<TestVersion> maybeTestVersion = TestVersionSupplier.testVersions().findFirst(); + assertThat(maybeTestVersion).isPresent(); + this.testVersion = maybeTestVersion.get(); logger.info("Testing with version={}", testVersion); - cluster = provisionClusterWithRetries(testVersion.get()); + cluster = provisionClusterWithRetries(this.testVersion); assertThat(cluster).isNotNull(); initializeSchemaForTest(); startSidecar(cluster); @@ -346,6 +354,45 @@ public abstract class SharedClusterIntegrationTestBase return cluster.coordinator(1).execute(String.format("SELECT * FROM %s;", table), consistencyLevel); } + /** + * Convenience method to query all data from the provided {@code 
table} at the specified consistency level. + * + * @param cluster the in-jvm dtest cluster to connect to + * @param table the qualified Cassandra table name + * @param consistency the consistency level at which to execute the query + * @return all the data queried from the table + */ + protected ResultSet queryAllDataWithDriver(ICluster cluster, QualifiedName table, shaded.com.datastax.driver.core.ConsistencyLevel consistency) + { + Cluster driverCluster = createDriverCluster(cluster); + Session session = driverCluster.connect(); + SimpleStatement statement = new SimpleStatement(String.format("SELECT * FROM %s;", table)); + statement.setConsistencyLevel(consistency); + return session.execute(statement); + } + + public static Cluster createDriverCluster(ICluster<? extends IInstance> dtest) + { + if (dtest.size() == 0) + { + throw new IllegalArgumentException("Attempted to open java driver for empty cluster"); + } + else + { + dtest.stream().forEach((i) -> { + if (!i.config().has(Feature.NATIVE_PROTOCOL) || !i.config().has(Feature.GOSSIP)) + { + throw new IllegalStateException("java driver requires Feature.NATIVE_PROTOCOL and Feature.GOSSIP; but one or more is missing"); + } + }); + Cluster.Builder builder = Cluster.builder(); + dtest.stream().forEach((i) -> { + builder.addContactPointsWithPorts(new InetSocketAddress(i.broadcastAddress().getAddress(), i.config().getInt("native_transport_port"))); + }); + + return builder.build(); + } + } + static class IntegrationTestModule extends AbstractModule { private final AbstractCluster<?
extends IInstance> cluster; diff --git a/cassandra-analytics-integration-tests/build.gradle b/cassandra-analytics-integration-tests/build.gradle index 711e09d..38cad7e 100644 --- a/cassandra-analytics-integration-tests/build.gradle +++ b/cassandra-analytics-integration-tests/build.gradle @@ -34,6 +34,7 @@ project(':cassandra-analytics-integration-tests') { } dependencies { + testImplementation(project(':cassandra-bridge')) testImplementation(project(':cassandra-analytics-core')) testImplementation(group: 'net.java.dev.jna', name: 'jna', version: '5.9.0') diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BulkWriteUdtTest.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BulkWriteUdtTest.java new file mode 100644 index 0000000..f8e7816 --- /dev/null +++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BulkWriteUdtTest.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.cassandra.analytics; + +import java.io.IOException; + +import org.junit.jupiter.api.Test; + +import com.vdurmont.semver4j.Semver; +import org.apache.cassandra.distributed.UpgradeableCluster; +import org.apache.cassandra.distributed.api.ConsistencyLevel; +import org.apache.cassandra.distributed.api.Feature; +import org.apache.cassandra.distributed.api.SimpleQueryResult; +import org.apache.cassandra.distributed.api.TokenSupplier; +import org.apache.cassandra.distributed.shared.Versions; +import org.apache.cassandra.sidecar.testing.JvmDTestSharedClassesPredicate; +import org.apache.cassandra.sidecar.testing.QualifiedName; +import org.apache.cassandra.testing.TestVersion; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.jetbrains.annotations.NotNull; + +import static org.apache.cassandra.testing.CassandraTestTemplate.fixDistributedSchemas; +import static org.apache.cassandra.testing.CassandraTestTemplate.waitForHealthyRing; +import static org.apache.cassandra.testing.TestUtils.DC1_RF3; +import static org.apache.cassandra.testing.TestUtils.ROW_COUNT; +import static org.apache.cassandra.testing.TestUtils.TEST_KEYSPACE; +import static org.assertj.core.api.Assertions.assertThat; + +class BulkWriteUdtTest extends SharedClusterSparkIntegrationTestBase +{ + static final QualifiedName UDT_TABLE_NAME = new QualifiedName(TEST_KEYSPACE, "test_udt"); + static final QualifiedName NESTED_TABLE_NAME = new QualifiedName(TEST_KEYSPACE, "test_nested_udt"); + public static final String TWO_FIELD_UDT_NAME = "two_field_udt"; + public static final String NESTED_FIELD_UDT_NAME = "nested_udt"; + public static final String UDT_TABLE_CREATE = "CREATE TABLE " + UDT_TABLE_NAME + " (\n" + + " id BIGINT PRIMARY KEY,\n" + + " udtfield " + TWO_FIELD_UDT_NAME + ");"; + public static final String TWO_FIELD_UDT_DEF = "CREATE TYPE " + UDT_TABLE_NAME.keyspace() + "." 
+ + TWO_FIELD_UDT_NAME + " (\n" + + " f1 text,\n" + + " f2 int);"; + public static final String NESTED_UDT_DEF = "CREATE TYPE " + NESTED_TABLE_NAME.keyspace() + "." + + NESTED_FIELD_UDT_NAME + " (\n" + + " n1 BIGINT,\n" + + " n2 frozen<" + TWO_FIELD_UDT_NAME + ">" + + ");"; + public static final String NESTED_TABLE_CREATE = "CREATE TABLE " + NESTED_TABLE_NAME + "(\n" + + " id BIGINT PRIMARY KEY,\n" + + " nested " + NESTED_FIELD_UDT_NAME + ");"; + @Test + void testWriteWithUdt() + { + SparkSession spark = getOrCreateSparkSession(); + Dataset<Row> df = DataGenerationUtils.generateUdtData(spark, ROW_COUNT); + + bulkWriterDataFrameWriter(df, UDT_TABLE_NAME).save(); + + SimpleQueryResult result = cluster.coordinator(1).executeWithResult("SELECT * FROM " + UDT_TABLE_NAME, ConsistencyLevel.ALL); + assertThat(result.hasNext()).isTrue(); + validateWritesWithDriverResultSet(df.collectAsList(), + queryAllDataWithDriver(cluster, UDT_TABLE_NAME, + shaded.com.datastax.driver.core.ConsistencyLevel.LOCAL_QUORUM), + BulkWriteUdtTest::defaultRowFormatter); + } + + @Test + void testWriteWithNestedUdt() + { + SparkSession spark = getOrCreateSparkSession(); + Dataset<Row> df = DataGenerationUtils.generateNestedUdtData(spark, ROW_COUNT); + + bulkWriterDataFrameWriter(df, NESTED_TABLE_NAME).save(); + + SimpleQueryResult result = cluster.coordinator(1).executeWithResult("SELECT * FROM " + NESTED_TABLE_NAME, ConsistencyLevel.ALL); + assertThat(result.hasNext()).isTrue(); + validateWritesWithDriverResultSet(df.collectAsList(), + queryAllDataWithDriver(cluster, NESTED_TABLE_NAME, shaded.com.datastax.driver.core.ConsistencyLevel.LOCAL_QUORUM), + BulkWriteUdtTest::defaultRowFormatter); + } + + @NotNull + public static String defaultRowFormatter(shaded.com.datastax.driver.core.Row row) + { + return row.getLong(0) + + ":" + + row.getUDTValue(1); // Formats as field:value with no whitespaces, and strings quoted + } + + @Override + protected UpgradeableCluster provisionCluster(TestVersion 
testVersion) throws IOException + { + // spin up a C* cluster using the in-jvm dtest + Versions versions = Versions.find(); + Versions.Version requestedVersion = versions.getLatest(new Semver(testVersion.version(), Semver.SemverType.LOOSE)); + + UpgradeableCluster.Builder clusterBuilder = + UpgradeableCluster.build(3) + .withDynamicPortAllocation(true) + .withVersion(requestedVersion) + .withDCs(1) + .withDataDirCount(1) + .withSharedClasses(JvmDTestSharedClassesPredicate.INSTANCE) + .withConfig(config -> config.with(Feature.NATIVE_PROTOCOL) + .with(Feature.GOSSIP) + .with(Feature.JMX)); + TokenSupplier tokenSupplier = TokenSupplier.evenlyDistributedTokens(3, clusterBuilder.getTokenCount()); + clusterBuilder.withTokenSupplier(tokenSupplier); + UpgradeableCluster cluster = clusterBuilder.createWithoutStarting(); + cluster.startup(); + + waitForHealthyRing(cluster); + fixDistributedSchemas(cluster); + return cluster; + } + + @Override + protected void initializeSchemaForTest() + { + createTestKeyspace(UDT_TABLE_NAME, DC1_RF3); + + cluster.schemaChangeIgnoringStoppedInstances(TWO_FIELD_UDT_DEF); + cluster.schemaChangeIgnoringStoppedInstances(NESTED_UDT_DEF); + cluster.schemaChangeIgnoringStoppedInstances(UDT_TABLE_CREATE); + cluster.schemaChangeIgnoringStoppedInstances(NESTED_TABLE_CREATE); + } +} diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/DataGenerationUtils.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/DataGenerationUtils.java index a0882c5..3c1897d 100644 --- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/DataGenerationUtils.java +++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/DataGenerationUtils.java @@ -31,11 +31,14 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession; +import 
org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.types.DataTypes.IntegerType; import static org.apache.spark.sql.types.DataTypes.LongType; import static org.apache.spark.sql.types.DataTypes.StringType; +import static org.apache.spark.sql.types.DataTypes.createStructType; /** * Utilities for data generation used for tests @@ -49,6 +52,7 @@ public final class DataGenerationUtils /** * Generates course data with schema + * Does not generate a User Defined Field * * <pre> * id integer, @@ -61,18 +65,48 @@ public final class DataGenerationUtils * @return a {@link Dataset} with generated data */ public static Dataset<Row> generateCourseData(SparkSession spark, int rowCount) + { + return generateCourseData(spark, rowCount, false); + } + + /** + * Generates course data with schema + * + * <pre> + * id integer, + * course string, + * marks integer + * </pre> + * + * @param spark the spark session to use + * @param rowCount the number of records to generate + * @param udfData if a field representing a User Defined Type should be added + * @return a {@link Dataset} with generated data + */ + public static Dataset<Row> generateCourseData(SparkSession spark, int rowCount, boolean udfData) { SQLContext sql = spark.sqlContext(); StructType schema = new StructType() .add("id", IntegerType, false) .add("course", StringType, false) .add("marks", IntegerType, false); + if (udfData) + { + StructType udfType = new StructType() + .add("TimE", IntegerType, false) + .add("limit", IntegerType, false); + schema = schema.add("User_Defined_Type", udfType); + } List<Row> rows = IntStream.range(0, rowCount) .mapToObj(recordNum -> { String course = "course" + recordNum; - Object[] values = {recordNum, course, recordNum}; - return RowFactory.create(values); + if (!udfData) + { + return RowFactory.create(recordNum, course, recordNum); + } + return RowFactory.create(recordNum, 
course, recordNum, + RowFactory.create(recordNum, recordNum)); }).collect(Collectors.toList()); return sql.createDataFrame(rows, schema); } @@ -115,6 +149,45 @@ public final class DataGenerationUtils return sql.createDataFrame(rows, schema); } + public static Dataset<Row> generateUdtData(SparkSession spark, int rowCount) + { + SQLContext sql = spark.sqlContext(); + StructType udtType = createStructType(new StructField[]{new StructField("f1", StringType, false, Metadata.empty()), + new StructField("f2", IntegerType, false, Metadata.empty())}); + StructType schema = new StructType() + .add("id", IntegerType, false) + .add("udtfield", udtType, false); + + List<Row> rows = IntStream.range(0, rowCount) + .mapToObj(id -> { + String course = "course" + id; + Object[] values = {id, RowFactory.create(course, id)}; + return RowFactory.create(values); + }).collect(Collectors.toList()); + return sql.createDataFrame(rows, schema); + } + + public static Dataset<Row> generateNestedUdtData(SparkSession spark, int rowCount) + { + SQLContext sql = spark.sqlContext(); + StructType udtType = createStructType(new StructField[]{new StructField("f1", StringType, false, Metadata.empty()), + new StructField("f2", IntegerType, false, Metadata.empty())}); + StructType nestedType = createStructType(new StructField[] {new StructField("n1", IntegerType, false, Metadata.empty()), + new StructField("n2", udtType, false, Metadata.empty())}); + StructType schema = new StructType() + .add("id", IntegerType, false) + .add("nested", nestedType, false); + + List<Row> rows = IntStream.range(0, rowCount) + .mapToObj(id -> { + String course = "course" + id; + Row innerUdt = RowFactory.create(id, RowFactory.create(course, id)); + Object[] values = {id, innerUdt}; + return RowFactory.create(values); + }).collect(Collectors.toList()); + return sql.createDataFrame(rows, schema); + } + private static String dupString(String string, Integer times) { byte[] stringBytes = string.getBytes(); diff --git 
a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/QuoteIdentifiersWriteTest.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/QuoteIdentifiersWriteTest.java index 2650cd7..7cd47dd 100644 --- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/QuoteIdentifiersWriteTest.java +++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/QuoteIdentifiersWriteTest.java @@ -40,9 +40,10 @@ import org.apache.cassandra.testing.TestVersion; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; +import org.jetbrains.annotations.NotNull; +import shaded.com.datastax.driver.core.ConsistencyLevel; import static org.apache.cassandra.analytics.DataGenerationUtils.generateCourseData; -import static org.apache.cassandra.analytics.SparkTestUtils.validateWrites; import static org.apache.cassandra.testing.TestUtils.DC1_RF1; import static org.apache.cassandra.testing.TestUtils.ROW_COUNT; import static org.apache.cassandra.testing.TestUtils.TEST_KEYSPACE; @@ -56,11 +57,13 @@ import static org.apache.cassandra.testing.TestUtils.uniqueTestTableFullName; */ class QuoteIdentifiersWriteTest extends SharedClusterSparkIntegrationTestBase { + static final QualifiedName TABLE_NAME_FOR_UDT_TEST = uniqueTestTableFullName("QuOtEd_KeYsPaCe", "QuOtEd_TaBlE"); static final List<QualifiedName> TABLE_NAMES = Arrays.asList(uniqueTestTableFullName("QuOtEd_KeYsPaCe"), uniqueTestTableFullName("keyspace"), // keyspace is a reserved word uniqueTestTableFullName(TEST_KEYSPACE, "QuOtEd_TaBlE"), - new QualifiedName(TEST_KEYSPACE, "table")); // table is a reserved word + new QualifiedName(TEST_KEYSPACE, "table"), // table is a reserved word + TABLE_NAME_FOR_UDT_TEST); @ParameterizedTest(name = "{index} => table={0}") @MethodSource("testInputs") @@ -69,12 +72,51 @@ class QuoteIdentifiersWriteTest extends 
SharedClusterSparkIntegrationTestBase SparkSession spark = getOrCreateSparkSession(); // Generates course data from and renames the dataframe columns to use case-sensitive and reserved // words in the dataframe - Dataset<Row> df = generateCourseData(spark, ROW_COUNT).toDF("IdEnTiFiEr", // case-sensitive struct - "course", - "limit"); // limit is a reserved word in Cassandra + boolean udfData = tableName.equals(TABLE_NAME_FOR_UDT_TEST); + Dataset<Row> df; + Dataset<Row> generatedDf = generateCourseData(spark, ROW_COUNT, udfData); + if (!udfData) + { + df = generatedDf.toDF("IdEnTiFiEr", // case-sensitive struct + "course", + "limit"); // limit is a reserved word in Cassandra + } + else + { + df = generatedDf.toDF("IdEnTiFiEr", // case-sensitive struct + "course", + "limit", // limit is a reserved word in Cassandra + "User_Defined_Type"); + } bulkWriterDataFrameWriter(df, tableName).option(WriterOptions.QUOTE_IDENTIFIERS.name(), "true") .save(); - validateWrites(df.collectAsList(), queryAllData(tableName)); + validateWritesWithDriverResultSet(df.collectAsList(), + queryAllDataWithDriver(cluster, tableName, + ConsistencyLevel.LOCAL_QUORUM), + udfData ? 
+ QuoteIdentifiersWriteTest::rowWithUdtFormatter : + QuoteIdentifiersWriteTest::defaultRowFormatter); + } + + public static String defaultRowFormatter(shaded.com.datastax.driver.core.Row row) + { + return row.getInt("IdEnTiFiEr") + + ":'" + + row.getString("course") + + "':" + + row.getInt("limit"); + } + + @NotNull + private static String rowWithUdtFormatter(shaded.com.datastax.driver.core.Row row) + { + return row.getInt("IdEnTiFiEr") + + ":'" + + row.getString("course") + + "':" + + row.getInt("limit") + + ":" + + row.getUDTValue("User_Defined_Type"); } static Stream<Arguments> testInputs() @@ -115,7 +157,22 @@ class QuoteIdentifiersWriteTest extends SharedClusterSparkIntegrationTestBase TABLE_NAMES.forEach(name -> { createTestKeyspace(name, DC1_RF1); - createTestTable(name, createTableStatement); + if (!name.equals(TABLE_NAME_FOR_UDT_TEST)) + { + createTestTable(name, createTableStatement); + } }); + + // Create UDT + String createUdtQuery = "CREATE TYPE " + TABLE_NAME_FOR_UDT_TEST.maybeQuotedKeyspace() + + ".\"UdT1\" (\"TimE\" bigint, \"limit\" int);"; + cluster.schemaChangeIgnoringStoppedInstances(createUdtQuery); + + createTestTable(TABLE_NAME_FOR_UDT_TEST, "CREATE TABLE IF NOT EXISTS %s (" + + "\"IdEnTiFiEr\" int, " + + "course text, " + + "\"limit\" int," + + "\"User_Defined_Type\" frozen<\"UdT1\">, " + + "PRIMARY KEY(\"IdEnTiFiEr\"));"); } } diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java index 699f705..b0807d8 100644 --- a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java +++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/SharedClusterSparkIntegrationTestBase.java @@ -20,6 +20,10 @@ package org.apache.cassandra.analytics; 
import java.net.UnknownHostException; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -27,7 +31,10 @@ import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; import org.junit.jupiter.api.extension.ExtendWith; +import com.vdurmont.semver4j.Semver; import io.vertx.junit5.VertxExtension; +import org.apache.cassandra.bridge.CassandraBridge; +import org.apache.cassandra.bridge.CassandraBridgeFactory; import org.apache.cassandra.distributed.shared.JMXUtil; import org.apache.cassandra.sidecar.testing.QualifiedName; import org.apache.cassandra.sidecar.testing.SharedClusterIntegrationTestBase; @@ -37,10 +44,13 @@ import org.apache.spark.sql.DataFrameWriter; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.StructField; +import shaded.com.datastax.driver.core.ResultSet; import static org.apache.cassandra.analytics.SparkTestUtils.defaultBulkReaderDataFrame; import static org.apache.cassandra.analytics.SparkTestUtils.defaultBulkWriterDataFrameWriter; import static org.apache.cassandra.analytics.SparkTestUtils.defaultSparkConf; +import static org.assertj.core.api.Assertions.assertThat; /** * Extends functionality from {@link SharedClusterIntegrationTestBase} and provides additional functionality for running @@ -52,6 +62,23 @@ public abstract class SharedClusterSparkIntegrationTestBase extends SharedCluste { protected SparkConf sparkConf; protected SparkSession sparkSession; + protected CassandraBridge bridge; + + public void validateWritesWithDriverResultSet(List<Row> sourceData, ResultSet queriedData, + Function<shaded.com.datastax.driver.core.Row, String> rowFormatter) + { + Set<String> actualEntries = new HashSet<>(); + queriedData.forEach(row -> actualEntries.add(rowFormatter.apply(row))); + + // Number 
of entries in Cassandra must match the original datasource + assertThat(actualEntries.size()).isEqualTo(sourceData.size()); + + // remove from actual entries to make sure that the data read is the same as the data written + Set<String> sourceEntries = sourceData.stream().map(this::getFormattedSourceEntry) + .collect(Collectors.toSet()); + assertThat(actualEntries).as("All entries are expected to be read from database") + .containsExactlyInAnyOrderElementsOf(sourceEntries); + } /** * A preconfigured {@link DataFrameReader} with pre-populated required options that can be overridden @@ -122,4 +149,65 @@ public abstract class SharedClusterSparkIntegrationTestBase extends SharedCluste } return sparkSession; } + + protected CassandraBridge getOrCreateBridge() + { + if (bridge == null) + { + Semver semVer = new Semver(testVersion.version(), + Semver.SemverType.LOOSE); + bridge = CassandraBridgeFactory.get(semVer.toStrict().toString()); + } + return bridge; + } + + private String getFormattedSourceEntry(Row row) + { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < row.size(); i++) + { + maybeFormatUdt(sb, row.get(i)); + if (i != (row.size() - 1)) + { + sb.append(":"); + } + } + return sb.toString(); + } + + // Format a Spark row to look like what the toString on a UDT looks like + // Unfortunately not _quite_ json, so we need to do this manually. 
+ protected void maybeFormatUdt(StringBuilder sb, Object o) + { + if (o instanceof Row) + { + Row r = (Row) o; + sb.append("{"); + StructField[] fields = r.schema().fields(); + for (int i = 0; i < r.size(); i++) + { + sb.append(maybeQuoteFieldName(fields[i])); + sb.append(":"); + maybeFormatUdt(sb, r.get(i)); + if (i != r.size() - 1) + { + sb.append(','); + } + } + sb.append("}"); + } + else if (o instanceof String) + { + sb.append(String.format("'%s'", o)); + } + else + { + sb.append(String.format("%s", o)); + } + } + + protected String maybeQuoteFieldName(StructField fields) + { + return getOrCreateBridge().maybeQuoteIdentifier(fields.name()); + } } diff --git a/cassandra-bridge/src/main/java/org/apache/cassandra/bridge/CassandraBridge.java b/cassandra-bridge/src/main/java/org/apache/cassandra/bridge/CassandraBridge.java index c89f2cd..b234065 100644 --- a/cassandra-bridge/src/main/java/org/apache/cassandra/bridge/CassandraBridge.java +++ b/cassandra-bridge/src/main/java/org/apache/cassandra/bridge/CassandraBridge.java @@ -388,6 +388,7 @@ public abstract class CassandraBridge String partitioner, String createStatement, String insertStatement, + Set<String> userDefinedTypeStatements, int bufferSizeMB); public interface IRow diff --git a/cassandra-bridge/src/main/java/org/apache/cassandra/spark/data/BridgeUdtValue.java b/cassandra-bridge/src/main/java/org/apache/cassandra/spark/data/BridgeUdtValue.java new file mode 100644 index 0000000..a6d1215 --- /dev/null +++ b/cassandra-bridge/src/main/java/org/apache/cassandra/spark/data/BridgeUdtValue.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.cassandra.spark.data; + +import java.io.Serializable; +import java.util.Map; +import java.util.Objects; + +/** + * The BridgeUdtValue class exists because the Cassandra values produced (UDTValue) are not serializable + * because they come from the classloader inside the bridge, and therefore can't be passed around + * from one Spark phase to another. Therefore, we build a map of these instances (potentially nested) + * and return them from the conversion stage for later use when the writer actually writes them. + */ +public class BridgeUdtValue implements Serializable +{ + public final String name; + public final Map<String, Object> udtMap; + + public BridgeUdtValue(String name, Map<String, Object> valueMap) + { + this.name = name; + this.udtMap = valueMap; + } + + public boolean equals(Object o) + { + if (this == o) + { + return true; + } + if (o == null || getClass() != o.getClass()) + { + return false; + } + BridgeUdtValue udtValue = (BridgeUdtValue) o; + return Objects.equals(name, udtValue.name) && Objects.equals(udtMap, udtValue.udtMap); + } + + public int hashCode() + { + return Objects.hash(name, udtMap); + } + + public String toString() + { + return "BridgeUdtValue{" + + "name='" + name + '\'' + + ", udtMap=" + udtMap + + '}'; + } +} diff --git a/cassandra-four-zero/src/main/java/org/apache/cassandra/bridge/CassandraBridgeImplementation.java b/cassandra-four-zero/src/main/java/org/apache/cassandra/bridge/CassandraBridgeImplementation.java index 56e6d46..0aa576e 100644 --- 
a/cassandra-four-zero/src/main/java/org/apache/cassandra/bridge/CassandraBridgeImplementation.java +++ b/cassandra-four-zero/src/main/java/org/apache/cassandra/bridge/CassandraBridgeImplementation.java @@ -592,9 +592,11 @@ public class CassandraBridgeImplementation extends CassandraBridge String partitioner, String createStatement, String insertStatement, + @NotNull Set<String> userDefinedTypeStatements, int bufferSizeMB) { - return new SSTableWriterImplementation(inDirectory, partitioner, createStatement, insertStatement, bufferSizeMB); + return new SSTableWriterImplementation(inDirectory, partitioner, createStatement, insertStatement, + userDefinedTypeStatements, bufferSizeMB); } // Version-Specific Test Utility Methods diff --git a/cassandra-four-zero/src/main/java/org/apache/cassandra/bridge/SSTableWriterImplementation.java b/cassandra-four-zero/src/main/java/org/apache/cassandra/bridge/SSTableWriterImplementation.java index 0a7ecde..fdc9cab 100644 --- a/cassandra-four-zero/src/main/java/org/apache/cassandra/bridge/SSTableWriterImplementation.java +++ b/cassandra-four-zero/src/main/java/org/apache/cassandra/bridge/SSTableWriterImplementation.java @@ -21,6 +21,7 @@ package org.apache.cassandra.bridge; import java.io.IOException; import java.util.Map; +import java.util.Set; import com.google.common.annotations.VisibleForTesting; @@ -30,6 +31,7 @@ import org.apache.cassandra.dht.Murmur3Partitioner; import org.apache.cassandra.dht.RandomPartitioner; import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.io.sstable.CQLSSTableWriter; +import org.jetbrains.annotations.NotNull; public class SSTableWriterImplementation implements SSTableWriter { @@ -44,6 +46,7 @@ public class SSTableWriterImplementation implements SSTableWriter String partitioner, String createStatement, String insertStatement, + @NotNull Set<String> userDefinedTypeStatements, int bufferSizeMB) { IPartitioner cassPartitioner = 
partitioner.toLowerCase().contains("random") ? new RandomPartitioner() @@ -53,6 +56,7 @@ public class SSTableWriterImplementation implements SSTableWriter createStatement, insertStatement, bufferSizeMB, + userDefinedTypeStatements, cassPartitioner); writer = builder.build(); } @@ -81,17 +85,23 @@ public class SSTableWriterImplementation implements SSTableWriter String createStatement, String insertStatement, int bufferSizeMB, + Set<String> udts, IPartitioner cassPartitioner) { - return CQLSSTableWriter - .builder() - .inDirectory(inDirectory) - .forTable(createStatement) - .withPartitioner(cassPartitioner) - .using(insertStatement) - // The data frame to write is always sorted, - // see org.apache.cassandra.spark.bulkwriter.CassandraBulkSourceRelation.insert - .sorted() - .withMaxSSTableSizeInMiB(bufferSizeMB); + CQLSSTableWriter.Builder builder = CQLSSTableWriter.builder(); + + for (String udt : udts) + { + builder.withType(udt); + } + + return builder.inDirectory(inDirectory) + .forTable(createStatement) + .withPartitioner(cassPartitioner) + .using(insertStatement) + // The data frame to write is always sorted, + // see org.apache.cassandra.spark.bulkwriter.CassandraBulkSourceRelation.insert + .sorted() + .withMaxSSTableSizeInMiB(bufferSizeMB); } } diff --git a/cassandra-four-zero/src/main/java/org/apache/cassandra/spark/data/complex/CqlUdt.java b/cassandra-four-zero/src/main/java/org/apache/cassandra/spark/data/complex/CqlUdt.java index e97ce6c..6000af0 100644 --- a/cassandra-four-zero/src/main/java/org/apache/cassandra/spark/data/complex/CqlUdt.java +++ b/cassandra-four-zero/src/main/java/org/apache/cassandra/spark/data/complex/CqlUdt.java @@ -162,6 +162,10 @@ public class CqlUdt extends CqlType implements CqlField.CqlUdt @Override public Object convertForCqlWriter(Object value, CassandraVersion version) { + if (value instanceof UDTValue) + { + return value; + } return toUserTypeValue(version, this, value); } diff --git 
a/cassandra-four-zero/src/test/java/org/apache/cassandra/bridge/SSTableWriterImplementationTest.java b/cassandra-four-zero/src/test/java/org/apache/cassandra/bridge/SSTableWriterImplementationTest.java index e5e5846..589ee4e 100644 --- a/cassandra-four-zero/src/test/java/org/apache/cassandra/bridge/SSTableWriterImplementationTest.java +++ b/cassandra-four-zero/src/test/java/org/apache/cassandra/bridge/SSTableWriterImplementationTest.java @@ -22,6 +22,7 @@ package org.apache.cassandra.bridge; import java.io.File; import java.lang.reflect.Field; import java.util.Arrays; +import java.util.HashSet; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; @@ -51,6 +52,7 @@ class SSTableWriterImplementationTest CREATE_STATEMENT, INSERT_STATEMENT, 250, + new HashSet<>(), new Murmur3Partitioner()); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org For additional commands, e-mail: commits-help@cassandra.apache.org