yifan-c commented on code in PR #45: URL: https://github.com/apache/cassandra-analytics/pull/45#discussion_r1540148669
########## cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraSchemaInfo.java: ########## @@ -19,19 +19,33 @@ package org.apache.cassandra.spark.bulkwriter; +import java.util.Set; + +import org.apache.cassandra.spark.data.CqlTable; + public class CassandraSchemaInfo implements SchemaInfo { private static final long serialVersionUID = -2327383232935001862L; private final TableSchema tableSchema; + private final Set<String> userDefinedTypeStatements; + private final CqlTable cqlTable; Review Comment: This variable is not exposed and not used anywhere. ########## cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SqlToCqlTypeConverter.java: ########## @@ -788,11 +797,54 @@ else if (object instanceof scala.Tuple2) } return map; } + } + + public static class UdtConverter extends NullableConverter<Object> + { + private final String name; + private final HashMap<Object, Converter<?>> converters; + private CqlField.CqlUdt udt; Review Comment: Unused. ########## cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SqlToCqlTypeConverter.java: ########## @@ -788,11 +797,54 @@ else if (object instanceof scala.Tuple2) } return map; } + } + + public static class UdtConverter extends NullableConverter<Object> + { + private final String name; + private final HashMap<Object, Converter<?>> converters; Review Comment: The key type is `String`. ########## cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SqlToCqlTypeConverter.java: ########## @@ -788,11 +797,54 @@ else if (object instanceof scala.Tuple2) } return map; } + } + + public static class UdtConverter extends NullableConverter<Object> Review Comment: Why does it convert to `Object` instead of `BridgeUdtValue`? 
########## cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/TableSchemaNormalizeTest.java: ########## @@ -296,6 +303,24 @@ public void testNestedNormalization() DataTypes.createArrayType(DataTypes.createArrayType(DataTypes.BinaryType)))); } + @Test + public void testUdtNormalization() + { + StructType structType = new StructType() + .add(new StructField("f1", DataTypes.IntegerType, false, Metadata.empty())) + .add(new StructField("f2", DataTypes.StringType, false, Metadata.empty())); + + GenericRowWithSchema source = new GenericRowWithSchema(new Object[]{1, "course"}, structType); + // NOTE: UDT Types cary their type name around, so the use of `udt_field` consistently here is a bit Review Comment: `carry`? ########## cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java: ########## @@ -108,6 +119,26 @@ private String getStreamId(TaskContext taskContext) return String.format("%d-%s", taskContext.partitionId(), UUID.randomUUID()); } + private CqlTable cqlTable() Review Comment: The only callsite of `cqlTable()` is already `synchronized` on record writer. I think you can remove the double-checked locking entirely, if keeping the synchronization at `getUdt()`. 
########## cassandra-four-zero/src/main/java/org/apache/cassandra/bridge/SSTableWriterImplementation.java: ########## @@ -81,17 +85,23 @@ static CQLSSTableWriter.Builder configureBuilder(String inDirectory, String createStatement, String insertStatement, int bufferSizeMB, + Set<String> udts, IPartitioner cassPartitioner) { - return CQLSSTableWriter - .builder() - .inDirectory(inDirectory) - .forTable(createStatement) - .withPartitioner(cassPartitioner) - .using(insertStatement) - // The data frame to write is always sorted, - // see org.apache.cassandra.spark.bulkwriter.CassandraBulkSourceRelation.insert - .sorted() - .withMaxSSTableSizeInMiB(bufferSizeMB); + CQLSSTableWriter.Builder builder = CQLSSTableWriter + .builder() + .inDirectory(inDirectory) + .forTable(createStatement) + .withPartitioner(cassPartitioner) + .using(insertStatement) + // The data frame to write is always sorted, + // see org.apache.cassandra.spark.bulkwriter.CassandraBulkSourceRelation.insert + .sorted() + .withMaxSSTableSizeInMiB(bufferSizeMB); + for (String udt : udts) + { + builder.withType(udt); + } + return builder; Review Comment: nit ```suggestion CQLSSTableWriter.Builder builder = CQLSSTableWriter.builder(); for (String udt : udts) { builder.withType(udt); } return builder .inDirectory(inDirectory) .forTable(createStatement) .withPartitioner(cassPartitioner) .using(insertStatement) // The data frame to write is always sorted, // see org.apache.cassandra.spark.bulkwriter.CassandraBulkSourceRelation.insert .sorted() .withMaxSSTableSizeInMiB(bufferSizeMB); ``` ########## cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java: ########## @@ -95,6 +105,7 @@ public RecordWriter(BulkWriterContext writerContext, String[] columnNames) this.writeValidator = new BulkWriteValidator(writerContext, failureHandler); this.digestAlgorithm = this.writerContext.job().digestAlgorithmSupplier().get(); + Review Comment: nit: remove the new empty line ########## 
cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SqlToCqlTypeConverter.java: ########## @@ -763,6 +766,12 @@ else if (object instanceof Map) throw new RuntimeException("Unsupported conversion for MAP from " + object.getClass().getTypeName()); } + @Override + public String toString() + { + return "Map"; Review Comment: How about the richer string? ```suggestion return "Map<" + keyConverter.toString() + ", " + valConverter.toString() + '>'; ``` ########## cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/BulkWriteUdtTest.java: ########## @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.cassandra.analytics; + +import java.io.IOException; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; + +import com.vdurmont.semver4j.Semver; +import org.apache.cassandra.distributed.UpgradeableCluster; +import org.apache.cassandra.distributed.api.ConsistencyLevel; +import org.apache.cassandra.distributed.api.Feature; +import org.apache.cassandra.distributed.api.SimpleQueryResult; +import org.apache.cassandra.distributed.api.TokenSupplier; +import org.apache.cassandra.distributed.shared.Versions; +import org.apache.cassandra.sidecar.testing.JvmDTestSharedClassesPredicate; +import org.apache.cassandra.sidecar.testing.QualifiedName; +import org.apache.cassandra.testing.TestVersion; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.StructField; +import shaded.com.datastax.driver.core.ResultSet; + +import static org.apache.cassandra.testing.CassandraTestTemplate.fixDistributedSchemas; +import static org.apache.cassandra.testing.CassandraTestTemplate.waitForHealthyRing; +import static org.apache.cassandra.testing.TestUtils.DC1_RF3; +import static org.apache.cassandra.testing.TestUtils.ROW_COUNT; +import static org.apache.cassandra.testing.TestUtils.TEST_KEYSPACE; +import static org.assertj.core.api.Assertions.assertThat; + +class BulkWriteUdtTest extends SharedClusterSparkIntegrationTestBase +{ + static final QualifiedName UDT_TABLE_NAME = new QualifiedName(TEST_KEYSPACE, "test_udt"); + static final QualifiedName NESTED_TABLE_NAME = new QualifiedName(TEST_KEYSPACE, "test_nested_udt"); + public static final String TWO_FIELD_UDT_NAME = "two_field_udt"; + public static final String NESTED_FIELD_UDT_NAME = "nested_udt"; + public static final String UDT_TABLE_CREATE = "CREATE TABLE " + UDT_TABLE_NAME + " (\n" + + " id BIGINT PRIMARY KEY,\n" + + " 
udtfield " + TWO_FIELD_UDT_NAME + ");"; + public static final String TWO_FIELD_UDT_DEF = "CREATE TYPE " + UDT_TABLE_NAME.keyspace() + "." + + TWO_FIELD_UDT_NAME + " (\n" + + " f1 text,\n" + + " f2 int);"; + public static final String NESTED_UDT_DEF = "CREATE TYPE " + NESTED_TABLE_NAME.keyspace() + "." + + NESTED_FIELD_UDT_NAME + " (\n" + + " n1 BIGINT,\n" + + " n2 frozen<" + TWO_FIELD_UDT_NAME + ">" + + ");"; + public static final String NESTED_TABLE_CREATE = "CREATE TABLE " + NESTED_TABLE_NAME + "(\n" + + " id BIGINT PRIMARY KEY,\n" + + " nested " + NESTED_FIELD_UDT_NAME + ");"; + @Test + void testWriteWithUdt() + { + SparkSession spark = getOrCreateSparkSession(); + Dataset<Row> df = DataGenerationUtils.generateUdtData(spark, ROW_COUNT); + + bulkWriterDataFrameWriter(df, UDT_TABLE_NAME).save(); + + SimpleQueryResult result = cluster.coordinator(1).executeWithResult("SELECT * FROM " + UDT_TABLE_NAME, ConsistencyLevel.ALL); + assertThat(result.hasNext()).isTrue(); + validateWrites(df.collectAsList(), queryAllDataWithDriver(cluster, UDT_TABLE_NAME, shaded.com.datastax.driver.core.ConsistencyLevel.LOCAL_QUORUM)); + } + + @Test + void testWriteWithNestedUdt() + { + SparkSession spark = getOrCreateSparkSession(); + Dataset<Row> df = DataGenerationUtils.generateNestedUdtData(spark, ROW_COUNT); + + bulkWriterDataFrameWriter(df, NESTED_TABLE_NAME).save(); + + SimpleQueryResult result = cluster.coordinator(1).executeWithResult("SELECT * FROM " + NESTED_TABLE_NAME, ConsistencyLevel.ALL); + assertThat(result.hasNext()).isTrue(); + validateWrites(df.collectAsList(), queryAllDataWithDriver(cluster, NESTED_TABLE_NAME, shaded.com.datastax.driver.core.ConsistencyLevel.LOCAL_QUORUM)); + } + + @Override + protected UpgradeableCluster provisionCluster(TestVersion testVersion) throws IOException + { + // spin up a C* cluster using the in-jvm dtest + Versions versions = Versions.find(); + Versions.Version requestedVersion = versions.getLatest(new Semver(testVersion.version(), 
Semver.SemverType.LOOSE)); + + UpgradeableCluster.Builder clusterBuilder = + UpgradeableCluster.build(3) + .withDynamicPortAllocation(true) + .withVersion(requestedVersion) + .withDCs(1) + .withDataDirCount(1) + .withSharedClasses(JvmDTestSharedClassesPredicate.INSTANCE) + .withConfig(config -> config.with(Feature.NATIVE_PROTOCOL) + .with(Feature.GOSSIP) + .with(Feature.JMX)); + TokenSupplier tokenSupplier = TokenSupplier.evenlyDistributedTokens(3, clusterBuilder.getTokenCount()); + clusterBuilder.withTokenSupplier(tokenSupplier); + UpgradeableCluster cluster = clusterBuilder.createWithoutStarting(); + cluster.startup(); + + waitForHealthyRing(cluster); + fixDistributedSchemas(cluster); + return cluster; + } + + @Override + protected void initializeSchemaForTest() + { + createTestKeyspace(UDT_TABLE_NAME, DC1_RF3); + + cluster.schemaChangeIgnoringStoppedInstances(TWO_FIELD_UDT_DEF); + cluster.schemaChangeIgnoringStoppedInstances(NESTED_UDT_DEF); + cluster.schemaChangeIgnoringStoppedInstances(UDT_TABLE_CREATE); + cluster.schemaChangeIgnoringStoppedInstances(NESTED_TABLE_CREATE); + } + + + public static void validateWrites(List<Row> sourceData, ResultSet queriedData) + { + Set<String> actualEntries = new HashSet<>(); + queriedData.forEach(row -> + actualEntries.add(getFormattedData(row))); + + // Number of entries in Cassandra must match the original datasource + assertThat(actualEntries.size()).isEqualTo(sourceData.size()); + + // remove from actual entries to make sure that the data read is the same as the data written + Set<String> sourceEntries = sourceData.stream().map(BulkWriteUdtTest::getFormatedSourceEntry) + .collect(Collectors.toSet()); + assertThat(actualEntries).as("All entries are expected to be read from database") + .containsExactlyInAnyOrderElementsOf(sourceEntries); + } + + private static String getFormatedSourceEntry(Row row) Review Comment: typo: `getFormattedSourceEntry` (double t) -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org For additional commands, e-mail: commits-h...@cassandra.apache.org