This is an automated email from the ASF dual-hosted git repository. reuvenlax pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new e039ca2 Merge pull request #14136: [BEAM-11648] Add conversion utilities for BigQuery Storage API sink e039ca2 is described below commit e039ca28d6f806f30b87cae82e6af86694c171cd Author: reuvenlax <re...@google.com> AuthorDate: Fri Mar 5 08:07:41 2021 -0800 Merge pull request #14136: [BEAM-11648] Add conversion utilities for BigQuery Storage API sink --- .../io/gcp/bigquery/BeamRowToStorageApiProto.java | 336 +++++++++++ .../beam/sdk/io/gcp/bigquery/BigQueryUtils.java | 221 +------ .../beam/sdk/io/gcp/bigquery/CivilTimeEncoder.java | 648 +++++++++++++++++++++ .../io/gcp/bigquery/TableRowToStorageApiProto.java | 304 ++++++++++ .../gcp/bigquery/BeamRowToStorageApiProtoTest.java | 391 +++++++++++++ .../sdk/io/gcp/bigquery/BigQueryUtilsTest.java | 260 --------- .../bigquery/TableRowToStorageApiProtoTest.java | 312 ++++++++++ 7 files changed, 1993 insertions(+), 479 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java new file mode 100644 index 0000000..816cbe9 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import com.google.protobuf.ByteString; +import com.google.protobuf.DescriptorProtos.DescriptorProto; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Label; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type; +import com.google.protobuf.DescriptorProtos.FileDescriptorProto; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.DescriptorValidationException; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.DynamicMessage; +import java.math.BigDecimal; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; +import javax.annotation.Nullable; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.Schema.LogicalType; +import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Functions; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Bytes; +import org.joda.time.ReadableInstant; + +/** + * Utility methods for converting Beam {@link Row} objects to dynamic protocol message, for use with + * the Storage write API. + */ +public class BeamRowToStorageApiProto { + // Number of digits after the decimal point supported by the NUMERIC data type. + private static final int NUMERIC_SCALE = 9; + // Maximum and minimum allowed values for the NUMERIC data type. + private static final BigDecimal MAX_NUMERIC_VALUE = + new BigDecimal("99999999999999999999999999999.999999999"); + private static final BigDecimal MIN_NUMERIC_VALUE = + new BigDecimal("-99999999999999999999999999999.999999999"); + + // TODO(reuvenlax): Support BIGNUMERIC and GEOGRAPHY types. + static final Map<TypeName, Type> PRIMITIVE_TYPES = + ImmutableMap.<TypeName, Type>builder() + .put(TypeName.INT16, Type.TYPE_INT32) + .put(TypeName.BYTE, Type.TYPE_INT32) + .put(TypeName.INT32, Type.TYPE_INT32) + .put(TypeName.INT64, Type.TYPE_INT64) + .put(TypeName.FLOAT, Type.TYPE_FLOAT) + .put(TypeName.DOUBLE, Type.TYPE_DOUBLE) + .put(TypeName.STRING, Type.TYPE_STRING) + .put(TypeName.BOOLEAN, Type.TYPE_BOOL) + .put(TypeName.DATETIME, Type.TYPE_INT64) + .put(TypeName.BYTES, Type.TYPE_BYTES) + .put(TypeName.DECIMAL, Type.TYPE_BYTES) + .build(); + + // A map of supported logical types to the protobuf field type. + static final Map<String, Type> LOGICAL_TYPES = + ImmutableMap.<String, Type>builder() + .put(SqlTypes.DATE.getIdentifier(), Type.TYPE_INT32) + .put(SqlTypes.TIME.getIdentifier(), Type.TYPE_INT64) + .put(SqlTypes.DATETIME.getIdentifier(), Type.TYPE_INT64) + .put(SqlTypes.TIMESTAMP.getIdentifier(), Type.TYPE_INT64) + .put(EnumerationType.IDENTIFIER, Type.TYPE_STRING) + .build(); + + static final Map<TypeName, Function<Object, Object>> PRIMITIVE_ENCODERS = + ImmutableMap.<TypeName, Function<Object, Object>>builder() + .put(TypeName.INT16, o -> Integer.valueOf((Short) o)) + .put(TypeName.BYTE, o -> Integer.valueOf((Byte) o)) + .put(TypeName.INT32, Functions.identity()) + .put(TypeName.INT64, Functions.identity()) + .put(TypeName.FLOAT, Function.identity()) + .put(TypeName.DOUBLE, Function.identity()) + .put(TypeName.STRING, Function.identity()) + .put(TypeName.BOOLEAN, Function.identity()) + // A Beam DATETIME is actually a timestamp, not a DateTime. + .put(TypeName.DATETIME, o -> ((ReadableInstant) o).getMillis() * 1000) + .put(TypeName.BYTES, o -> ByteString.copyFrom((byte[]) o)) + .put(TypeName.DECIMAL, o -> serializeBigDecimalToNumeric((BigDecimal) o)) + .build(); + + // A map of supported logical types to their encoding functions. + static final Map<String, BiFunction<LogicalType<?, ?>, Object, Object>> LOGICAL_TYPE_ENCODERS = + ImmutableMap.<String, BiFunction<LogicalType<?, ?>, Object, Object>>builder() + .put( + SqlTypes.DATE.getIdentifier(), + (logicalType, value) -> (int) ((LocalDate) value).toEpochDay()) + .put( + SqlTypes.TIME.getIdentifier(), + (logicalType, value) -> CivilTimeEncoder.encodePacked64TimeMicros((LocalTime) value)) + .put( + SqlTypes.DATETIME.getIdentifier(), + (logicalType, value) -> + CivilTimeEncoder.encodePacked64DatetimeSeconds((LocalDateTime) value)) + .put( + SqlTypes.TIMESTAMP.getIdentifier(), + (logicalType, value) -> ((java.time.Instant) value).toEpochMilli() * 1000) + .put( + EnumerationType.IDENTIFIER, + (logicalType, value) -> + ((EnumerationType) logicalType).toString((EnumerationType.Value) value)) + .build(); + + /** + * Given a Beam Schema, returns a protocol-buffer Descriptor that can be used to write data using + * the BigQuery Storage API. + */ + public static Descriptor getDescriptorFromSchema(Schema schema) + throws DescriptorValidationException { + DescriptorProto descriptorProto = descriptorSchemaFromBeamSchema(schema); + FileDescriptorProto fileDescriptorProto = + FileDescriptorProto.newBuilder().addMessageType(descriptorProto).build(); + FileDescriptor fileDescriptor = + FileDescriptor.buildFrom(fileDescriptorProto, new FileDescriptor[0]); + + return Iterables.getOnlyElement(fileDescriptor.getMessageTypes()); + } + + /** + * Given a Beam {@link Row} object, returns a protocol-buffer message that can be used to write + * data using the BigQuery Storage streaming API. + */ + public static DynamicMessage messageFromBeamRow(Descriptor descriptor, Row row) { + Schema beamSchema = row.getSchema(); + DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor); + for (int i = 0; i < row.getFieldCount(); ++i) { + Field beamField = beamSchema.getField(i); + FieldDescriptor fieldDescriptor = + Preconditions.checkNotNull(descriptor.findFieldByName(beamField.getName().toLowerCase())); + @Nullable Object value = messageValueFromRowValue(fieldDescriptor, beamField, i, row); + if (value != null) { + builder.setField(fieldDescriptor, value); + } + } + return builder.build(); + } + + @VisibleForTesting + static DescriptorProto descriptorSchemaFromBeamSchema(Schema schema) { + Preconditions.checkState(schema.getFieldCount() > 0); + DescriptorProto.Builder descriptorBuilder = DescriptorProto.newBuilder(); + // Create a unique name for the descriptor ('-' characters cannot be used). + descriptorBuilder.setName("D" + UUID.randomUUID().toString().replace("-", "_")); + int i = 1; + List<DescriptorProto> nestedTypes = Lists.newArrayList(); + for (Field field : schema.getFields()) { + FieldDescriptorProto.Builder fieldDescriptorProtoBuilder = + fieldDescriptorFromBeamField(field, i++, nestedTypes); + descriptorBuilder.addField(fieldDescriptorProtoBuilder); + } + nestedTypes.forEach(descriptorBuilder::addNestedType); + return descriptorBuilder.build(); + } + + private static FieldDescriptorProto.Builder fieldDescriptorFromBeamField( + Field field, int fieldNumber, List<DescriptorProto> nestedTypes) { + FieldDescriptorProto.Builder fieldDescriptorBuilder = FieldDescriptorProto.newBuilder(); + fieldDescriptorBuilder = fieldDescriptorBuilder.setName(field.getName().toLowerCase()); + fieldDescriptorBuilder = fieldDescriptorBuilder.setNumber(fieldNumber); + + switch (field.getType().getTypeName()) { + case ROW: + @Nullable Schema rowSchema = field.getType().getRowSchema(); + if (rowSchema == null) { + throw new RuntimeException("Unexpected null schema!"); + } + DescriptorProto nested = descriptorSchemaFromBeamSchema(rowSchema); + nestedTypes.add(nested); + fieldDescriptorBuilder = + fieldDescriptorBuilder.setType(Type.TYPE_MESSAGE).setTypeName(nested.getName()); + break; + case ARRAY: + case ITERABLE: + @Nullable FieldType elementType = field.getType().getCollectionElementType(); + if (elementType == null) { + throw new RuntimeException("Unexpected null element type!"); + } + Preconditions.checkState( + !Preconditions.checkNotNull(elementType.getTypeName()).isCollectionType(), + "Nested arrays not supported by BigQuery."); + return fieldDescriptorFromBeamField( + Field.of(field.getName(), elementType), fieldNumber, nestedTypes) + .setLabel(Label.LABEL_REPEATED); + case LOGICAL_TYPE: + @Nullable LogicalType<?, ?> logicalType = field.getType().getLogicalType(); + if (logicalType == null) { + throw new RuntimeException("Unexpected null logical type " + field.getType()); + } + @Nullable Type type = LOGICAL_TYPES.get(logicalType.getIdentifier()); + if (type == null) { + throw new RuntimeException("Unsupported logical type " + field.getType()); + } + fieldDescriptorBuilder = fieldDescriptorBuilder.setType(type); + break; + case MAP: + throw new RuntimeException("Map types not supported by BigQuery."); + default: + @Nullable Type primitiveType = PRIMITIVE_TYPES.get(field.getType().getTypeName()); + if (primitiveType == null) { + throw new RuntimeException("Unsupported type " + field.getType()); + } + fieldDescriptorBuilder = fieldDescriptorBuilder.setType(primitiveType); + } + if (field.getType().getNullable()) { + fieldDescriptorBuilder = fieldDescriptorBuilder.setLabel(Label.LABEL_OPTIONAL); + } else { + fieldDescriptorBuilder = fieldDescriptorBuilder.setLabel(Label.LABEL_REQUIRED); + } + return fieldDescriptorBuilder; + } + + @Nullable + private static Object messageValueFromRowValue( + FieldDescriptor fieldDescriptor, Field beamField, int index, Row row) { + @Nullable Object value = row.getValue(index); + if (value == null) { + if (fieldDescriptor.isOptional()) { + return null; + } else { + throw new IllegalArgumentException( + "Received null value for non-nullable field " + fieldDescriptor.getName()); + } + } + return toProtoValue(fieldDescriptor, beamField.getType(), value); + } + + private static Object toProtoValue( + FieldDescriptor fieldDescriptor, FieldType beamFieldType, Object value) { + switch (beamFieldType.getTypeName()) { + case ROW: + return messageFromBeamRow(fieldDescriptor.getMessageType(), (Row) value); + case ARRAY: + List<Object> list = (List<Object>) value; + @Nullable FieldType arrayElementType = beamFieldType.getCollectionElementType(); + if (arrayElementType == null) { + throw new RuntimeException("Unexpected null element type!"); + } + return list.stream() + .map(v -> toProtoValue(fieldDescriptor, arrayElementType, v)) + .collect(Collectors.toList()); + case ITERABLE: + Iterable<Object> iterable = (Iterable<Object>) value; + @Nullable FieldType iterableElementType = beamFieldType.getCollectionElementType(); + if (iterableElementType == null) { + throw new RuntimeException("Unexpected null element type!"); + } + return StreamSupport.stream(iterable.spliterator(), false) + .map(v -> toProtoValue(fieldDescriptor, iterableElementType, v)) + .collect(Collectors.toList()); + case MAP: + throw new RuntimeException("Map types not supported by BigQuery."); + default: + return scalarToProtoValue(beamFieldType, value); + } + } + + @VisibleForTesting + static Object scalarToProtoValue(FieldType beamFieldType, Object value) { + if (beamFieldType.getTypeName() == TypeName.LOGICAL_TYPE) { + @Nullable LogicalType<?, ?> logicalType = beamFieldType.getLogicalType(); + if (logicalType == null) { + throw new RuntimeException("Unexpectedly null logical type " + beamFieldType); + } + @Nullable + BiFunction<LogicalType<?, ?>, Object, Object> logicalTypeEncoder = + LOGICAL_TYPE_ENCODERS.get(logicalType.getIdentifier()); + if (logicalTypeEncoder == null) { + throw new RuntimeException("Unsupported logical type " + logicalType.getIdentifier()); + } + return logicalTypeEncoder.apply(logicalType, value); + } else { + @Nullable + Function<Object, Object> encoder = PRIMITIVE_ENCODERS.get(beamFieldType.getTypeName()); + if (encoder == null) { + throw new RuntimeException("Unexpected beam type " + beamFieldType); + } + return encoder.apply(value); + } + } + + static ByteString serializeBigDecimalToNumeric(BigDecimal o) { + return serializeBigDecimal(o, NUMERIC_SCALE, MAX_NUMERIC_VALUE, MIN_NUMERIC_VALUE, "Numeric"); + } + + private static ByteString serializeBigDecimal( + BigDecimal v, int scale, BigDecimal maxValue, BigDecimal minValue, String typeName) { + if (v.scale() > scale) { + throw new IllegalArgumentException( + typeName + " scale cannot exceed " + scale + ": " + v.toPlainString()); + } + if (v.compareTo(maxValue) > 0 || v.compareTo(minValue) < 0) { + throw new IllegalArgumentException(typeName + " overflow: " + v.toPlainString()); + } + + byte[] bytes = v.setScale(scale).unscaledValue().toByteArray(); + // NUMERIC/BIGNUMERIC values are serialized as scaled integers in two's complement form in + // little endian + // order. BigInteger requires the same encoding but in big endian order, therefore we must + // reverse the bytes that come from the proto. + Bytes.reverse(bytes); + return ByteString.copyFrom(bytes); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java index e3380ad..4ce40b6 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java @@ -25,31 +25,17 @@ import com.google.api.services.bigquery.model.TableFieldSchema; import com.google.api.services.bigquery.model.TableRow; import com.google.api.services.bigquery.model.TableSchema; import com.google.auto.value.AutoValue; -import com.google.protobuf.ByteString; -import com.google.protobuf.DescriptorProtos.DescriptorProto; -import com.google.protobuf.DescriptorProtos.FieldDescriptorProto; -import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Label; -import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type; -import com.google.protobuf.DescriptorProtos.FileDescriptorProto; -import com.google.protobuf.Descriptors.Descriptor; -import com.google.protobuf.Descriptors.DescriptorValidationException; -import com.google.protobuf.Descriptors.FieldDescriptor; -import com.google.protobuf.Descriptors.FileDescriptor; -import com.google.protobuf.DynamicMessage; -import com.google.protobuf.Message; import java.io.Serializable; import java.math.BigDecimal; import java.nio.ByteBuffer; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.LocalTime; -import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.UUID; import java.util.function.Function; import java.util.stream.IntStream; import org.apache.avro.Conversions; @@ -87,22 +73,6 @@ import org.joda.time.format.DateTimeFormatterBuilder; "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402) }) public class BigQueryUtils { - - /** - * Given a BigQuery TableSchema, returns a protocol-buffer Descriptor that can be used to write - * data using the Vortex streaming API. - */ - public static Descriptor getDescriptorFromTableSchema(TableSchema jsonSchema) - throws DescriptorValidationException { - DescriptorProto descriptorProto = descriptorSchemaFromTableSchema(jsonSchema); - FileDescriptorProto fileDescriptorProto = - FileDescriptorProto.newBuilder().addMessageType(descriptorProto).build(); - FileDescriptor fileDescriptor = - FileDescriptor.buildFrom(fileDescriptorProto, new FileDescriptor[0]); - - return Iterables.getOnlyElement(fileDescriptor.getMessageTypes()); - } - /** Options for how to convert BigQuery data to Beam data. */ @AutoValue public abstract static class ConversionOptions implements Serializable { @@ -254,7 +224,7 @@ public class BigQueryUtils { // TODO: BigQuery code should not be relying on Calcite metadata fields. If so, this belongs // in the SQL package. - private static final Map<String, StandardSQLTypeName> BEAM_TO_BIGQUERY_LOGICAL_MAPPING = + static final Map<String, StandardSQLTypeName> BEAM_TO_BIGQUERY_LOGICAL_MAPPING = ImmutableMap.<String, StandardSQLTypeName>builder() .put(SqlTypes.DATE.getIdentifier(), StandardSQLTypeName.DATE) .put(SqlTypes.TIME.getIdentifier(), StandardSQLTypeName.TIME) @@ -271,7 +241,7 @@ public class BigQueryUtils { * Get the corresponding BigQuery {@link StandardSQLTypeName} for supported Beam {@link * FieldType}. */ - private static StandardSQLTypeName toStandardSQLTypeName(FieldType fieldType) { + static StandardSQLTypeName toStandardSQLTypeName(FieldType fieldType) { StandardSQLTypeName ret; if (fieldType.getTypeName().isLogicalType()) { ret = BEAM_TO_BIGQUERY_LOGICAL_MAPPING.get(fieldType.getLogicalType().getIdentifier()); @@ -434,193 +404,6 @@ public class BigQueryUtils { return fromTableFieldSchema(tableSchema.getFields(), options); } - static DescriptorProto descriptorSchemaFromTableSchema(TableSchema tableSchema) { - return descriptorSchemaFromTableFieldSchemas(tableSchema.getFields()); - } - - static DescriptorProto descriptorSchemaFromTableFieldSchemas( - Iterable<TableFieldSchema> tableFieldSchemas) { - DescriptorProto.Builder descriptorBuilder = DescriptorProto.newBuilder(); - // Create a unique name for the descriptor ('-' characters cannot be used). - descriptorBuilder.setName("D" + UUID.randomUUID().toString().replace("-", "_")); - int i = 1; - for (TableFieldSchema fieldSchema : tableFieldSchemas) { - fieldDescriptorFromTableField(fieldSchema, i++, descriptorBuilder); - } - return descriptorBuilder.build(); - } - - static void fieldDescriptorFromTableField( - TableFieldSchema fieldSchema, int fieldNumber, DescriptorProto.Builder descriptorBuilder) { - FieldDescriptorProto.Builder fieldDescriptorBuilder = FieldDescriptorProto.newBuilder(); - fieldDescriptorBuilder = fieldDescriptorBuilder.setName(fieldSchema.getName()); - fieldDescriptorBuilder = fieldDescriptorBuilder.setNumber(fieldNumber); - switch (fieldSchema.getType()) { - case "STRING": - fieldDescriptorBuilder = fieldDescriptorBuilder.setType(Type.TYPE_STRING); - break; - case "BYTES": - fieldDescriptorBuilder = fieldDescriptorBuilder.setType(Type.TYPE_BYTES); - break; - case "INT64": - case "INTEGER": - fieldDescriptorBuilder = fieldDescriptorBuilder.setType(Type.TYPE_INT64); - break; - case "FLOAT64": - case "FLOAT": - fieldDescriptorBuilder = fieldDescriptorBuilder.setType(Type.TYPE_FLOAT); - break; - case "BOOL": - case "BOOLEAN": - fieldDescriptorBuilder = fieldDescriptorBuilder.setType(Type.TYPE_BOOL); - break; - case "TIMESTAMP": - case "TIME": - case "DATETIME": - fieldDescriptorBuilder = fieldDescriptorBuilder.setType(Type.TYPE_INT64); - break; - case "DATE": - fieldDescriptorBuilder = fieldDescriptorBuilder.setType(Type.TYPE_INT32); - break; - case "STRUCT": - case "RECORD": - DescriptorProto nested = descriptorSchemaFromTableFieldSchemas(fieldSchema.getFields()); - descriptorBuilder.addNestedType(nested); - fieldDescriptorBuilder = - fieldDescriptorBuilder.setType(Type.TYPE_MESSAGE).setTypeName(nested.getName()); - break; - default: - throw new UnsupportedOperationException( - "Converting BigQuery type " + fieldSchema.getType() + " to Beam type is unsupported"); - } - - Optional<Mode> fieldMode = Optional.ofNullable(fieldSchema.getMode()).map(Mode::valueOf); - if (fieldMode.filter(m -> m == Mode.REPEATED).isPresent()) { - fieldDescriptorBuilder = fieldDescriptorBuilder.setLabel(Label.LABEL_REPEATED); - } else if (!fieldMode.isPresent() || fieldMode.filter(m -> m == Mode.NULLABLE).isPresent()) { - fieldDescriptorBuilder = fieldDescriptorBuilder.setLabel(Label.LABEL_OPTIONAL); - } else { - fieldDescriptorBuilder = fieldDescriptorBuilder.setLabel(Label.LABEL_REQUIRED); - } - descriptorBuilder.addField(fieldDescriptorBuilder.build()); - } - - /** - * Given a BigQuery TableRow, returns a protocol-buffer message that can be used to write data - * using the Vortex streaming API. - */ - public static DynamicMessage messageFromTableRow(Descriptor descriptor, TableRow tableRow) { - DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor); - for (FieldDescriptor fieldDescriptor : descriptor.getFields()) { - Object value = - messageValueFromFieldValue(fieldDescriptor, tableRow.get(fieldDescriptor.getName())); - if (value != null) { - builder.setField(fieldDescriptor, value); - } - } - return builder.build(); - } - - static Object messageValueFromFieldValue(FieldDescriptor fieldDescriptor, Object bqValue) { - if (bqValue == null) { - if (fieldDescriptor.isOptional()) { - return null; - } else { - throw new IllegalArgumentException( - "Received null value for non-nullable field " + fieldDescriptor.getName()); - } - } - return toProtoValue(fieldDescriptor, bqValue); - } - - private static final Map<FieldDescriptor.Type, Function<String, Object>> JSON_PROTO_PARSERS = - ImmutableMap.<FieldDescriptor.Type, Function<String, Object>>builder() - .put(FieldDescriptor.Type.INT32, Integer::valueOf) - .put(FieldDescriptor.Type.INT64, Long::valueOf) - .put(FieldDescriptor.Type.FLOAT, Float::valueOf) - .put(FieldDescriptor.Type.DOUBLE, Double::valueOf) - .put(FieldDescriptor.Type.BOOL, Boolean::valueOf) - .put(FieldDescriptor.Type.STRING, str -> str) - .put( - FieldDescriptor.Type.BYTES, - b64 -> ByteString.copyFrom(BaseEncoding.base64().decode(b64))) - .build(); - - private static Object toProtoValue(FieldDescriptor fieldDescriptor, Object jsonBQValue) { - if (jsonBQValue instanceof String) { - Function<String, Object> mapper = JSON_PROTO_PARSERS.get(fieldDescriptor.getType()); - if (mapper != null) { - return mapper.apply((String) jsonBQValue); - } - } else if (jsonBQValue instanceof Integer) { - switch (fieldDescriptor.getJavaType()) { - case INT: - return Integer.valueOf((int) jsonBQValue); - case LONG: - return Long.valueOf((int) jsonBQValue); - default: - throw new RuntimeException( - "Unexpectecd java type " - + jsonBQValue.getClass() - + " for field descriptor " - + fieldDescriptor); - } - } else if (jsonBQValue instanceof List) { - return ((List<Object>) jsonBQValue) - .stream() - .map(v -> ((Map<String, Object>) v).get("v")) - .map(v -> toProtoValue(fieldDescriptor, v)) - .collect(toList()); - } else if (jsonBQValue instanceof AbstractMap) { - // This will handle nested rows. - TableRow tr = new TableRow(); - tr.putAll((AbstractMap<String, Object>) jsonBQValue); - return messageFromTableRow(fieldDescriptor.getMessageType(), tr); - } else { - return toProtoValue(fieldDescriptor, jsonBQValue.toString()); - } - - throw new UnsupportedOperationException( - "Converting BigQuery type '" - + jsonBQValue.getClass() - + "' to '" - + fieldDescriptor - + "' is not supported"); - } - - public static TableRow tableRowFromMessage(Message message) { - TableRow tableRow = new TableRow(); - for (Map.Entry<FieldDescriptor, Object> field : message.getAllFields().entrySet()) { - FieldDescriptor fieldDescriptor = field.getKey(); - Object fieldValue = field.getValue(); - tableRow.putIfAbsent( - fieldDescriptor.getName(), jsonValueFromMessageValue(fieldDescriptor, fieldValue, true)); - } - return tableRow; - } - - public static Object jsonValueFromMessageValue( - FieldDescriptor fieldDescriptor, Object fieldValue, boolean expandRepeated) { - if (expandRepeated && fieldDescriptor.isRepeated()) { - List<Object> valueList = (List<Object>) fieldValue; - return valueList.stream() - .map(v -> jsonValueFromMessageValue(fieldDescriptor, v, false)) - .collect(toList()); - } - - switch (fieldDescriptor.getType()) { - case GROUP: - case MESSAGE: - return tableRowFromMessage((Message) fieldValue); - case BYTES: - return BaseEncoding.base64().encode(((ByteString) fieldValue).toByteArray()); - case ENUM: - throw new RuntimeException("Enumerations not supported"); - default: - return fieldValue; - } - } - /** Convert a list of BigQuery {@link TableFieldSchema} to Avro {@link org.apache.avro.Schema}. */ @Experimental(Kind.SCHEMAS) public static org.apache.avro.Schema toGenericAvroSchema( diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CivilTimeEncoder.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CivilTimeEncoder.java new file mode 100644 index 0000000..bed767b --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CivilTimeEncoder.java @@ -0,0 +1,648 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; + +import java.time.temporal.ChronoUnit; +import org.joda.time.LocalDateTime; +import org.joda.time.LocalTime; + +/** + * Encoder for TIME and DATETIME values, according to civil_time encoding. Copied out of the zetasql + * package. + * + * <p>The valid range and number of bits required by each date/time field is as the following: + * + * <table> + * <tr> <th> Field </th> <th> Range </th> <th> #Bits </th> </tr> + * <tr> <td> Year </td> <td> [1, 9999] </td> <td> 14 </td> </tr> + * <tr> <td> Month </td> <td> [1, 12] </td> <td> 4 </td> </tr> + * <tr> <td> Day </td> <td> [1, 31] </td> <td> 5 </td> </tr> + * <tr> <td> Hour </td> <td> [0, 23] </td> <td> 5 </td> </tr> + * <tr> <td> Minute </td> <td> [0, 59] </td> <td> 6 </td> </tr> + * <tr> <td> Second </td> <td> [0, 59]* </td> <td> 6 </td> </tr> + * <tr> <td> Micros </td> <td> [0, 999999] </td> <td> 20 </td> </tr> + * <tr> <td> Nanos </td> <td> [0, 999999999] </td> <td> 30 </td> </tr> + * </table> + * + * <p>* Leap second is not supported. + * + * <p>When encoding the TIME or DATETIME into a bit field, larger date/time field is on the more + * significant side. + */ +public final class CivilTimeEncoder { + private static final int NANO_LENGTH = 30; + private static final int MICRO_LENGTH = 20; + + private static final int NANO_SHIFT = 0; + private static final int MICRO_SHIFT = 0; + private static final int SECOND_SHIFT = 0; + private static final int MINUTE_SHIFT = 6; + private static final int HOUR_SHIFT = 12; + private static final int DAY_SHIFT = 17; + private static final int MONTH_SHIFT = 22; + private static final int YEAR_SHIFT = 26; + + private static final long NANO_MASK = 0x3FFFFFFFL; + private static final long MICRO_MASK = 0xFFFFFL; + private static final long SECOND_MASK = 0x3FL; + private static final long MINUTE_MASK = 0xFC0L; + private static final long HOUR_MASK = 0x1F000L; + private static final long DAY_MASK = 0x3E0000L; + private static final long MONTH_MASK = 0x3C00000L; + private static final long YEAR_MASK = 0xFFFC000000L; + + private static final long TIME_SECONDS_MASK = 0x1FFFFL; + private static final long TIME_MICROS_MASK = 0x1FFFFFFFFFL; + private static final long TIME_NANOS_MASK = 0x7FFFFFFFFFFFL; + private static final long DATETIME_SECONDS_MASK = 0xFFFFFFFFFFL; + private static final long DATETIME_MICROS_MASK = 0xFFFFFFFFFFFFFFFL; + + /** + * Encodes {@code time} as a 4-byte integer with seconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 3 2 1 + * MSB 10987654321098765432109876543210 LSB + * | H || M || S | + * </pre> + * + * @see #decodePacked32TimeSeconds(int) + * @see #encodePacked32TimeSeconds(java.time.LocalTime) + */ + @SuppressWarnings("GoodTime") // should accept a java.time.LocalTime + public static int encodePacked32TimeSeconds(LocalTime time) { + checkValidTimeSeconds(time); + int bitFieldTimeSeconds = 0x0; + bitFieldTimeSeconds |= time.getHourOfDay() << HOUR_SHIFT; + bitFieldTimeSeconds |= time.getMinuteOfHour() << MINUTE_SHIFT; + bitFieldTimeSeconds |= time.getSecondOfMinute() << SECOND_SHIFT; + return bitFieldTimeSeconds; + } + + /** + * Encodes {@code time} as a 4-byte integer with seconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 3 2 1 + * MSB 10987654321098765432109876543210 LSB + * | H || M || S | + * </pre> + * + * @see #decodePacked32TimeSecondsAsJavaTime(int) + */ + @SuppressWarnings("GoodTime-ApiWithNumericTimeUnit") + public static int encodePacked32TimeSeconds(java.time.LocalTime time) { + checkValidTimeSeconds(time); + int bitFieldTimeSeconds = 0x0; + bitFieldTimeSeconds |= time.getHour() << HOUR_SHIFT; + bitFieldTimeSeconds |= time.getMinute() << MINUTE_SHIFT; + bitFieldTimeSeconds |= time.getSecond() << SECOND_SHIFT; + return bitFieldTimeSeconds; + } + + /** + * Decodes {@code bitFieldTimeSeconds} as a {@link LocalTime} with seconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 3 2 1 + * MSB 10987654321098765432109876543210 LSB + * | H || M || S | + * </pre> + * + * @see #encodePacked32TimeSeconds(LocalTime) + * @see #encodePacked32TimeSecondsAsJavaTime(int) + */ + @SuppressWarnings("GoodTime") // should return a java.time.LocalTime + public static LocalTime decodePacked32TimeSeconds(int bitFieldTimeSeconds) { + checkValidBitField(bitFieldTimeSeconds, TIME_SECONDS_MASK); + int hourOfDay = getFieldFromBitField(bitFieldTimeSeconds, HOUR_MASK, HOUR_SHIFT); + int minuteOfHour = getFieldFromBitField(bitFieldTimeSeconds, MINUTE_MASK, MINUTE_SHIFT); + int secondOfMinute = getFieldFromBitField(bitFieldTimeSeconds, SECOND_MASK, SECOND_SHIFT); + LocalTime time = new LocalTime(hourOfDay, minuteOfHour, secondOfMinute); + checkValidTimeSeconds(time); + return time; + } + + /** + * Decodes {@code bitFieldTimeSeconds} as a {@link java.time.LocalTime} with seconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 3 2 1 + * MSB 10987654321098765432109876543210 LSB + * | H || M || S | + * </pre> + * + * @see #encodePacked32TimeSeconds(java.time.LocalTime) + */ + @SuppressWarnings("GoodTime-ApiWithNumericTimeUnit") + public static java.time.LocalTime decodePacked32TimeSecondsAsJavaTime(int bitFieldTimeSeconds) { + checkValidBitField(bitFieldTimeSeconds, TIME_SECONDS_MASK); + int hourOfDay = getFieldFromBitField(bitFieldTimeSeconds, HOUR_MASK, HOUR_SHIFT); + int minuteOfHour = getFieldFromBitField(bitFieldTimeSeconds, MINUTE_MASK, MINUTE_SHIFT); + int secondOfMinute = getFieldFromBitField(bitFieldTimeSeconds, SECOND_MASK, SECOND_SHIFT); + // java.time.LocalTime validates the input parameters. + try { + return java.time.LocalTime.of(hourOfDay, minuteOfHour, secondOfMinute); + } catch (java.time.DateTimeException e) { + throw new IllegalArgumentException(e.getMessage(), e); + } + } + + /** + * Encodes {@code time} as a 8-byte integer with microseconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * | H || M || S ||-------micros-----| + * </pre> + * + * @see #decodePacked64TimeMicros(long) + * @see #encodePacked64TimeMicros(java.time.LocalTime) + */ + @SuppressWarnings("GoodTime") // should accept a java.time.LocalTime + public static long encodePacked64TimeMicros(LocalTime time) { + checkValidTimeMillis(time); + return (((long) encodePacked32TimeSeconds(time)) << MICRO_LENGTH) + | (time.getMillisOfSecond() * 1_000L); + } + + /** + * Encodes {@code time} as a 8-byte integer with microseconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * | H || M || S ||-------micros-----| + * </pre> + * + * @see #decodePacked64TimeMicrosAsJavaTime(long) + */ + @SuppressWarnings({"GoodTime-ApiWithNumericTimeUnit", "JavaLocalTimeGetNano"}) + public static long encodePacked64TimeMicros(java.time.LocalTime time) { + checkValidTimeMicros(time); + return (((long) encodePacked32TimeSeconds(time)) << MICRO_LENGTH) | (time.getNano() / 1_000L); + } + + /** + * Decodes {@code bitFieldTimeMicros} as a {@link LocalTime} with microseconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * | H || M || S ||-------micros-----| + * </pre> + * + * <p><b>Warning: LocalTime only supports milliseconds precision. Result is truncated.</b> + * + * @see #encodePacked64TimeMicros(LocalTime) + * @see #decodePacked64TimeMicrosAsJavaTime(long) + */ + @SuppressWarnings("GoodTime") // should return a java.time.LocalTime + public static LocalTime decodePacked64TimeMicros(long bitFieldTimeMicros) { + checkValidBitField(bitFieldTimeMicros, TIME_MICROS_MASK); + int bitFieldTimeSeconds = (int) (bitFieldTimeMicros >> MICRO_LENGTH); + LocalTime timeSeconds = decodePacked32TimeSeconds(bitFieldTimeSeconds); + int microOfSecond = getFieldFromBitField(bitFieldTimeMicros, MICRO_MASK, MICRO_SHIFT); + checkValidMicroOfSecond(microOfSecond); + LocalTime time = timeSeconds.withMillisOfSecond(microOfSecond / 1_000); + checkValidTimeMillis(time); + return time; + } + + /** + * Decodes {@code bitFieldTimeMicros} as a {@link java.time.LocalTime} with microseconds + * precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * | H || M || S ||-------micros-----| + * </pre> + * + * @see #encodePacked64TimeMicros(java.time.LocalTime) + */ + @SuppressWarnings("GoodTime-ApiWithNumericTimeUnit") + public static java.time.LocalTime decodePacked64TimeMicrosAsJavaTime(long bitFieldTimeMicros) { + checkValidBitField(bitFieldTimeMicros, TIME_MICROS_MASK); + int bitFieldTimeSeconds = (int) (bitFieldTimeMicros >> MICRO_LENGTH); + java.time.LocalTime timeSeconds = decodePacked32TimeSecondsAsJavaTime(bitFieldTimeSeconds); + int microOfSecond = getFieldFromBitField(bitFieldTimeMicros, MICRO_MASK, MICRO_SHIFT); + checkValidMicroOfSecond(microOfSecond); + return timeSeconds.withNano(microOfSecond * 1_000); + } + + /** + * Encodes {@code time} as a 8-byte integer with nanoseconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * | H || M || S ||---------- nanos -----------| + * </pre> + * + * @see #decodePacked64TimeNanos(long) + * @see #encodePacked64TimeNanos(java.time.LocalTime) + */ + @SuppressWarnings("GoodTime") // should accept a java.time.LocalTime + public static long encodePacked64TimeNanos(LocalTime time) { + checkValidTimeMillis(time); + return (((long) encodePacked32TimeSeconds(time)) << NANO_LENGTH) + | (time.getMillisOfSecond() * 1_000_000L); + } + + /** + * Encodes {@code time} as a 8-byte integer with nanoseconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * | H || M || S ||---------- nanos -----------| + * </pre> + * + * @see #decodePacked64TimeNanosAsJavaTime(long) + */ + @SuppressWarnings({"GoodTime-ApiWithNumericTimeUnit", "JavaLocalTimeGetNano"}) + public static long encodePacked64TimeNanos(java.time.LocalTime time) { + checkValidTimeNanos(time); + return (((long) encodePacked32TimeSeconds(time)) << NANO_LENGTH) | time.getNano(); + } + + /** + * Decodes {@code bitFieldTimeNanos} as a {@link LocalTime} with nanoseconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * | H || M || S ||---------- nanos -----------| + * </pre> + * + * <p><b>Warning: LocalTime only supports milliseconds precision. Result is truncated.</b> + * + * @see #encodePacked64TimeNanos(LocalTime) + * @see #decodePacked64TimeNanosAsJavaTime(long) + */ + @SuppressWarnings("GoodTime") // should return a java.time.LocalTime + public static LocalTime decodePacked64TimeNanos(long bitFieldTimeNanos) { + checkValidBitField(bitFieldTimeNanos, TIME_NANOS_MASK); + int bitFieldTimeSeconds = (int) (bitFieldTimeNanos >> NANO_LENGTH); + LocalTime timeSeconds = decodePacked32TimeSeconds(bitFieldTimeSeconds); + int nanoOfSecond = getFieldFromBitField(bitFieldTimeNanos, NANO_MASK, NANO_SHIFT); + checkValidNanoOfSecond(nanoOfSecond); + LocalTime time = timeSeconds.withMillisOfSecond(nanoOfSecond / 1_000_000); + checkValidTimeMillis(time); + return time; + } + + /** + * Decodes {@code bitFieldTimeNanos} as a {@link java.time.LocalTime} with nanoseconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * | H || M || S ||---------- nanos -----------| + * </pre> + * + * @see #encodePacked64TimeNanos(java.time.LocalTime) + */ + @SuppressWarnings("GoodTime-ApiWithNumericTimeUnit") + public static java.time.LocalTime decodePacked64TimeNanosAsJavaTime(long bitFieldTimeNanos) { + checkValidBitField(bitFieldTimeNanos, TIME_NANOS_MASK); + int bitFieldTimeSeconds = (int) (bitFieldTimeNanos >> NANO_LENGTH); + java.time.LocalTime timeSeconds = decodePacked32TimeSecondsAsJavaTime(bitFieldTimeSeconds); + int nanoOfSecond = getFieldFromBitField(bitFieldTimeNanos, NANO_MASK, NANO_SHIFT); + checkValidNanoOfSecond(nanoOfSecond); + return timeSeconds.withNano(nanoOfSecond); + } + + /** + * Encodes {@code dateTime} as a 8-byte integer with seconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * |--- year ---||m || D || H || M || S | + * </pre> + * + * @see #decodePacked64DatetimeSeconds(long) + * @see #encodePacked64DatetimeSeconds(java.time.LocalDateTime) + */ + @SuppressWarnings("GoodTime") // should accept a java.time.LocalDateTime + public static long encodePacked64DatetimeSeconds(LocalDateTime dateTime) { + checkValidDateTimeSeconds(dateTime); + long bitFieldDatetimeSeconds = 0x0L; + bitFieldDatetimeSeconds |= (long) dateTime.getYear() << YEAR_SHIFT; + bitFieldDatetimeSeconds |= (long) dateTime.getMonthOfYear() << MONTH_SHIFT; + bitFieldDatetimeSeconds |= (long) dateTime.getDayOfMonth() << DAY_SHIFT; + bitFieldDatetimeSeconds |= (long) encodePacked32TimeSeconds(dateTime.toLocalTime()); + return bitFieldDatetimeSeconds; + } + + /** + * Encodes {@code dateTime} as a 8-byte integer with seconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * |--- year ---||m || D || H || M || S | + * </pre> + * + * @see #decodePacked64DatetimeSecondsAsJavaTime(long) + */ + @SuppressWarnings("GoodTime-ApiWithNumericTimeUnit") + public static long encodePacked64DatetimeSeconds(java.time.LocalDateTime dateTime) { + checkValidDateTimeSeconds(dateTime); + long bitFieldDatetimeSeconds = 0x0L; + bitFieldDatetimeSeconds |= (long) dateTime.getYear() << YEAR_SHIFT; + bitFieldDatetimeSeconds |= (long) dateTime.getMonthValue() << MONTH_SHIFT; + bitFieldDatetimeSeconds |= (long) dateTime.getDayOfMonth() << DAY_SHIFT; + bitFieldDatetimeSeconds |= (long) encodePacked32TimeSeconds(dateTime.toLocalTime()); + return bitFieldDatetimeSeconds; + } + + /** + * Decodes {@code bitFieldDatetimeSeconds} as a {@link LocalDateTime} with seconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * |--- year ---||m || D || H || M || S | + * </pre> + * + * @see #encodePacked64DatetimeSeconds(LocalDateTime) + * @see #decodePacked64DatetimeSecondsAsJavaTime(long) + */ + @SuppressWarnings("GoodTime") // should return a java.time.LocalDateTime + public static LocalDateTime decodePacked64DatetimeSeconds(long bitFieldDatetimeSeconds) { + checkValidBitField(bitFieldDatetimeSeconds, DATETIME_SECONDS_MASK); + int bitFieldTimeSeconds = (int) (bitFieldDatetimeSeconds & TIME_SECONDS_MASK); + LocalTime timeSeconds = decodePacked32TimeSeconds(bitFieldTimeSeconds); + int year = getFieldFromBitField(bitFieldDatetimeSeconds, YEAR_MASK, YEAR_SHIFT); + int monthOfYear = getFieldFromBitField(bitFieldDatetimeSeconds, MONTH_MASK, MONTH_SHIFT); + int dayOfMonth = getFieldFromBitField(bitFieldDatetimeSeconds, DAY_MASK, DAY_SHIFT); + LocalDateTime dateTime = + new LocalDateTime( + year, + monthOfYear, + dayOfMonth, + timeSeconds.getHourOfDay(), + timeSeconds.getMinuteOfHour(), + timeSeconds.getSecondOfMinute()); + checkValidDateTimeSeconds(dateTime); + return dateTime; + } + + /** + * Decodes {@code bitFieldDatetimeSeconds} as a {@link LocalDateTime} with seconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * |--- year ---||m || D || H || M || S | + * </pre> + * + * @see #encodePacked64DatetimeSeconds(java.time.LocalDateTime) + */ + @SuppressWarnings("GoodTime-ApiWithNumericTimeUnit") + public static java.time.LocalDateTime decodePacked64DatetimeSecondsAsJavaTime( + long bitFieldDatetimeSeconds) { + checkValidBitField(bitFieldDatetimeSeconds, DATETIME_SECONDS_MASK); + int bitFieldTimeSeconds = (int) (bitFieldDatetimeSeconds & TIME_SECONDS_MASK); + java.time.LocalTime timeSeconds = decodePacked32TimeSecondsAsJavaTime(bitFieldTimeSeconds); + int year = getFieldFromBitField(bitFieldDatetimeSeconds, YEAR_MASK, YEAR_SHIFT); + int monthOfYear = getFieldFromBitField(bitFieldDatetimeSeconds, MONTH_MASK, MONTH_SHIFT); + int dayOfMonth = getFieldFromBitField(bitFieldDatetimeSeconds, DAY_MASK, DAY_SHIFT); + try { + java.time.LocalDateTime dateTime = + java.time.LocalDateTime.of( + year, + monthOfYear, + dayOfMonth, + timeSeconds.getHour(), + timeSeconds.getMinute(), + timeSeconds.getSecond()); + checkValidDateTimeSeconds(dateTime); + return dateTime; + } catch (java.time.DateTimeException e) { + throw new IllegalArgumentException(e.getMessage(), e); + } + } + + /** + * Encodes {@code dateTime} as a 8-byte integer with microseconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * |--- year ---||m || D || H || M || S ||-------micros-----| + * </pre> + * + * @see #decodePacked64DatetimeMicros(long) + * @see #encodePacked64DatetimeMicros(java.time.LocalDateTime) + */ + @SuppressWarnings("GoodTime") // should accept a java.time.LocalDateTime + public static long encodePacked64DatetimeMicros(LocalDateTime dateTime) { + checkValidDateTimeMillis(dateTime); + return (encodePacked64DatetimeSeconds(dateTime) << MICRO_LENGTH) + | (dateTime.getMillisOfSecond() * 1_000L); + } + + /** + * Encodes {@code dateTime} as a 8-byte integer with microseconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * |--- year ---||m || D || H || M || S ||-------micros-----| + * </pre> + * + * @see #decodePacked64DatetimeMicrosAsJavaTime(long) + */ + @SuppressWarnings({"GoodTime-ApiWithNumericTimeUnit", "JavaLocalDateTimeGetNano"}) + public static long encodePacked64DatetimeMicros(java.time.LocalDateTime dateTime) { + checkValidDateTimeMicros(dateTime); + return (encodePacked64DatetimeSeconds(dateTime) << MICRO_LENGTH) + | (dateTime.getNano() / 1_000L); + } + + /** + * Decodes {@code bitFieldDatetimeMicros} as a {@link LocalDateTime} with microseconds precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * |--- year ---||m || D || H || M || S ||-------micros-----| + * </pre> + * + * <p><b>Warning: LocalDateTime only supports milliseconds precision. Result is truncated.</b> + * + * @see #encodePacked64DatetimeMicros(LocalDateTime) + * @see #decodePacked64DatetimeMicrosAsJavaTime(long) + */ + @SuppressWarnings("GoodTime") // should return a java.time.LocalDateTime + public static LocalDateTime decodePacked64DatetimeMicros(long bitFieldDatetimeMicros) { + checkValidBitField(bitFieldDatetimeMicros, DATETIME_MICROS_MASK); + long bitFieldDatetimeSeconds = bitFieldDatetimeMicros >> MICRO_LENGTH; + LocalDateTime dateTimeSeconds = decodePacked64DatetimeSeconds(bitFieldDatetimeSeconds); + int microOfSecond = getFieldFromBitField(bitFieldDatetimeMicros, MICRO_MASK, MICRO_SHIFT); + checkValidMicroOfSecond(microOfSecond); + LocalDateTime dateTime = dateTimeSeconds.withMillisOfSecond(microOfSecond / 1_000); + checkValidDateTimeMillis(dateTime); + return dateTime; + } + + /** + * Decodes {@code bitFieldDatetimeMicros} as a {@link java.time.LocalDateTime} with microseconds + * precision. + * + * <p>Encoding is as the following: + * + * <pre> + * 6 5 4 3 2 1 + * MSB 3210987654321098765432109876543210987654321098765432109876543210 LSB + * |--- year ---||m || D || H || M || S ||-------micros-----| + * </pre> + * + * @see #encodePacked64DatetimeMicros(java.time.LocalDateTime) + */ + @SuppressWarnings("GoodTime-ApiWithNumericTimeUnit") + public static java.time.LocalDateTime decodePacked64DatetimeMicrosAsJavaTime( + long bitFieldDatetimeMicros) { + checkValidBitField(bitFieldDatetimeMicros, DATETIME_MICROS_MASK); + long bitFieldDatetimeSeconds = bitFieldDatetimeMicros >> MICRO_LENGTH; + java.time.LocalDateTime dateTimeSeconds = + decodePacked64DatetimeSecondsAsJavaTime(bitFieldDatetimeSeconds); + int microOfSecond = getFieldFromBitField(bitFieldDatetimeMicros, MICRO_MASK, MICRO_SHIFT); + checkValidMicroOfSecond(microOfSecond); + java.time.LocalDateTime dateTime = dateTimeSeconds.withNano(microOfSecond * 1_000); + checkValidDateTimeMicros(dateTime); + return dateTime; + } + + private static int getFieldFromBitField(long bitField, long mask, int shift) { + return (int) ((bitField & mask) >> shift); + } + + private static void checkValidTimeSeconds(LocalTime time) { + checkArgument(time.getHourOfDay() >= 0 && time.getHourOfDay() <= 23); + checkArgument(time.getMinuteOfHour() >= 0 && time.getMinuteOfHour() <= 59); + checkArgument(time.getSecondOfMinute() >= 0 && time.getSecondOfMinute() <= 59); + } + + private static void checkValidTimeSeconds(java.time.LocalTime time) { + checkArgument(time.getHour() >= 0 && time.getHour() <= 23); + checkArgument(time.getMinute() >= 0 && time.getMinute() <= 59); + checkArgument(time.getSecond() >= 0 && time.getSecond() <= 59); + } + + private static void checkValidTimeMillis(LocalTime time) { + checkValidTimeSeconds(time); + checkArgument(time.getMillisOfSecond() >= 0 && time.getMillisOfSecond() <= 999); + } + + private static void checkValidTimeMicros(java.time.LocalTime time) { + checkValidTimeSeconds(time); + checkArgument(time.equals(time.truncatedTo(ChronoUnit.MICROS))); + } + + private static void checkValidTimeNanos(java.time.LocalTime time) { + checkValidTimeSeconds(time); + } + + private static void checkValidDateTimeSeconds(LocalDateTime dateTime) { + checkArgument(dateTime.getYear() >= 1 && dateTime.getYear() <= 9999); + checkArgument(dateTime.getMonthOfYear() >= 1 && dateTime.getMonthOfYear() <= 12); + checkArgument(dateTime.getDayOfMonth() >= 1 && dateTime.getDayOfMonth() <= 31); + checkValidTimeSeconds(dateTime.toLocalTime()); + } + + private static void checkValidDateTimeSeconds(java.time.LocalDateTime dateTime) { + checkArgument(dateTime.getYear() >= 1 && dateTime.getYear() <= 9999); + checkArgument(dateTime.getMonthValue() >= 1 && dateTime.getMonthValue() <= 12); + checkArgument(dateTime.getDayOfMonth() >= 1 && dateTime.getDayOfMonth() <= 31); + checkValidTimeSeconds(dateTime.toLocalTime()); + } + + private static void checkValidDateTimeMillis(LocalDateTime dateTime) { + checkValidDateTimeSeconds(dateTime); + checkArgument(dateTime.getMillisOfSecond() >= 0 && dateTime.getMillisOfSecond() <= 999); + } + + private static void checkValidDateTimeMicros(java.time.LocalDateTime dateTime) { + checkValidDateTimeSeconds(dateTime); + checkArgument(dateTime.equals(dateTime.truncatedTo(ChronoUnit.MICROS))); + } + + private static void checkValidDateTimeNanos(java.time.LocalDateTime dateTime) { + checkValidDateTimeSeconds(dateTime); + } + + private static void checkValidMicroOfSecond(int microOfSecond) { + checkArgument(microOfSecond >= 0 && microOfSecond <= 999999); + } + + private static void checkValidNanoOfSecond(int nanoOfSecond) { + checkArgument(nanoOfSecond >= 0 && nanoOfSecond <= 999999999); + } + + private static void checkValidBitField(long bitField, long mask) { + checkArgument((bitField & ~mask) == 0x0L); + } + + private CivilTimeEncoder() {} +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java new file mode 100644 index 0000000..873832f --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static java.util.stream.Collectors.toList; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.protobuf.ByteString; +import com.google.protobuf.DescriptorProtos.DescriptorProto; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Label; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type; +import com.google.protobuf.DescriptorProtos.FileDescriptorProto; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.DescriptorValidationException; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; +import java.util.AbstractMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.function.Function; +import javax.annotation.Nullable; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.BaseEncoding; + +/** + * Utility methods for converting JSON {@link TableRow} objects to dynamic protocol message, for use + * with the Storage write API. + */ +public class TableRowToStorageApiProto { + static final Map<String, Type> PRIMITIVE_TYPES = + ImmutableMap.<String, Type>builder() + .put("INT64", Type.TYPE_INT64) + .put("INTEGER", Type.TYPE_INT64) + .put("FLOAT64", Type.TYPE_DOUBLE) + .put("FLOAT", Type.TYPE_DOUBLE) + .put("STRING", Type.TYPE_STRING) + .put("BOOL", Type.TYPE_BOOL) + .put("BOOLEAN", Type.TYPE_BOOL) + .put("BYTES", Type.TYPE_BYTES) + .put("NUMERIC", Type.TYPE_STRING) // Pass through the JSON encoding. + .put("BIGNUMERIC", Type.TYPE_STRING) // Pass through the JSON encoding. + .put("GEOGRAPHY", Type.TYPE_STRING) // Pass through the JSON encoding. + .put("DATE", Type.TYPE_STRING) // Pass through the JSON encoding. + .put("TIME", Type.TYPE_STRING) // Pass through the JSON encoding. + .put("DATETIME", Type.TYPE_STRING) // Pass through the JSON encoding. + .put("TIMESTAMP", Type.TYPE_STRING) // Pass through the JSON encoding. + .build(); + + /** + * Given a BigQuery TableSchema, returns a protocol-buffer Descriptor that can be used to write + * data using the BigQuery Storage API. + */ + public static Descriptor getDescriptorFromTableSchema(TableSchema jsonSchema) + throws DescriptorValidationException { + DescriptorProto descriptorProto = descriptorSchemaFromTableSchema(jsonSchema); + FileDescriptorProto fileDescriptorProto = + FileDescriptorProto.newBuilder().addMessageType(descriptorProto).build(); + FileDescriptor fileDescriptor = + FileDescriptor.buildFrom(fileDescriptorProto, new FileDescriptor[0]); + + return Iterables.getOnlyElement(fileDescriptor.getMessageTypes()); + } + + /** + * Given a BigQuery TableRow, returns a protocol-buffer message that can be used to write data + * using the BigQuery Storage API. + */ + public static DynamicMessage messageFromTableRow(Descriptor descriptor, TableRow tableRow) { + DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor); + for (Map.Entry<String, Object> entry : tableRow.entrySet()) { + FieldDescriptor fieldDescriptor = descriptor.findFieldByName(entry.getKey().toLowerCase()); + Object value = messageValueFromFieldValue(fieldDescriptor, entry.getValue()); + if (value != null) { + builder.setField(fieldDescriptor, value); + } + } + return builder.build(); + } + + @VisibleForTesting + static DescriptorProto descriptorSchemaFromTableSchema(TableSchema tableSchema) { + return descriptorSchemaFromTableFieldSchemas(tableSchema.getFields()); + } + + private static DescriptorProto descriptorSchemaFromTableFieldSchemas( + Iterable<TableFieldSchema> tableFieldSchemas) { + DescriptorProto.Builder descriptorBuilder = DescriptorProto.newBuilder(); + // Create a unique name for the descriptor ('-' characters cannot be used). + descriptorBuilder.setName("D" + UUID.randomUUID().toString().replace("-", "_")); + int i = 1; + for (TableFieldSchema fieldSchema : tableFieldSchemas) { + fieldDescriptorFromTableField(fieldSchema, i++, descriptorBuilder); + } + return descriptorBuilder.build(); + } + + private static void fieldDescriptorFromTableField( + TableFieldSchema fieldSchema, int fieldNumber, DescriptorProto.Builder descriptorBuilder) { + FieldDescriptorProto.Builder fieldDescriptorBuilder = FieldDescriptorProto.newBuilder(); + fieldDescriptorBuilder = fieldDescriptorBuilder.setName(fieldSchema.getName().toLowerCase()); + fieldDescriptorBuilder = fieldDescriptorBuilder.setNumber(fieldNumber); + switch (fieldSchema.getType()) { + case "STRUCT": + case "RECORD": + DescriptorProto nested = descriptorSchemaFromTableFieldSchemas(fieldSchema.getFields()); + descriptorBuilder.addNestedType(nested); + fieldDescriptorBuilder = + fieldDescriptorBuilder.setType(Type.TYPE_MESSAGE).setTypeName(nested.getName()); + break; + default: + @Nullable Type type = PRIMITIVE_TYPES.get(fieldSchema.getType()); + if (type == null) { + throw new UnsupportedOperationException( + "Converting BigQuery type " + fieldSchema.getType() + " to Beam type is unsupported"); + } + fieldDescriptorBuilder = fieldDescriptorBuilder.setType(type); + } + + Optional<Mode> fieldMode = Optional.ofNullable(fieldSchema.getMode()).map(Mode::valueOf); + if (fieldMode.filter(m -> m == Mode.REPEATED).isPresent()) { + fieldDescriptorBuilder = fieldDescriptorBuilder.setLabel(Label.LABEL_REPEATED); + } else if (!fieldMode.isPresent() || fieldMode.filter(m -> m == Mode.NULLABLE).isPresent()) { + fieldDescriptorBuilder = fieldDescriptorBuilder.setLabel(Label.LABEL_OPTIONAL); + } else { + fieldDescriptorBuilder = fieldDescriptorBuilder.setLabel(Label.LABEL_REQUIRED); + } + descriptorBuilder.addField(fieldDescriptorBuilder.build()); + } + + @Nullable + private static Object messageValueFromFieldValue( + FieldDescriptor fieldDescriptor, Object bqValue) { + if (bqValue == null) { + if (fieldDescriptor.isOptional()) { + return null; + } else { + throw new IllegalArgumentException( + "Received null value for non-nullable field " + fieldDescriptor.getName()); + } + } + return toProtoValue(fieldDescriptor, bqValue, fieldDescriptor.isRepeated()); + } + + private static final Map<FieldDescriptor.Type, Function<String, Object>> + JSON_PROTO_STRING_PARSERS = + ImmutableMap.<FieldDescriptor.Type, Function<String, Object>>builder() + .put(FieldDescriptor.Type.INT32, Integer::valueOf) + .put(FieldDescriptor.Type.INT64, Long::valueOf) + .put(FieldDescriptor.Type.FLOAT, Float::valueOf) + .put(FieldDescriptor.Type.DOUBLE, Double::valueOf) + .put(FieldDescriptor.Type.BOOL, Boolean::valueOf) + .put(FieldDescriptor.Type.STRING, str -> str) + .put( + FieldDescriptor.Type.BYTES, + b64 -> ByteString.copyFrom(BaseEncoding.base64().decode(b64))) + .build(); + + @Nullable + @SuppressWarnings({"nullness"}) + @VisibleForTesting + static Object toProtoValue( + FieldDescriptor fieldDescriptor, Object jsonBQValue, boolean isRepeated) { + if (isRepeated) { + return ((List<Object>) jsonBQValue) + .stream() + .map( + v -> { + if (fieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE) { + return ((Map<String, Object>) v).get("v"); + } else { + return v; + } + }) + .map(v -> toProtoValue(fieldDescriptor, v, false)) + .collect(toList()); + } + + if (fieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE) { + if (jsonBQValue instanceof AbstractMap) { + // This will handle nested rows. + TableRow tr = new TableRow(); + tr.putAll((AbstractMap<String, Object>) jsonBQValue); + return messageFromTableRow(fieldDescriptor.getMessageType(), tr); + } else { + throw new RuntimeException("Unexpected value " + jsonBQValue + " Expected a JSON map."); + } + } + @Nullable Object scalarValue = scalarToProtoValue(fieldDescriptor, jsonBQValue); + if (scalarValue == null) { + return toProtoValue(fieldDescriptor, jsonBQValue.toString(), isRepeated); + } else { + return scalarValue; + } + } + + @VisibleForTesting + @Nullable + static Object scalarToProtoValue(FieldDescriptor fieldDescriptor, Object jsonBQValue) { + if (jsonBQValue instanceof String) { + Function<String, Object> mapper = JSON_PROTO_STRING_PARSERS.get(fieldDescriptor.getType()); + if (mapper == null) { + throw new UnsupportedOperationException( + "Converting BigQuery type '" + + jsonBQValue.getClass() + + "' to '" + + fieldDescriptor + + "' is not supported"); + } + return mapper.apply((String) jsonBQValue); + } + + switch (fieldDescriptor.getType()) { + case BOOL: + if (jsonBQValue instanceof Boolean) { + return jsonBQValue; + } + break; + case BYTES: + break; + case INT64: + if (jsonBQValue instanceof Integer) { + return Long.valueOf((Integer) jsonBQValue); + } else if (jsonBQValue instanceof Long) { + return jsonBQValue; + } + break; + case INT32: + if (jsonBQValue instanceof Integer) { + return jsonBQValue; + } + break; + case STRING: + break; + case DOUBLE: + if (jsonBQValue instanceof Double) { + return jsonBQValue; + } else if (jsonBQValue instanceof Float) { + return Double.valueOf((Float) jsonBQValue); + } + break; + default: + throw new RuntimeException("Unsupported proto type " + fieldDescriptor.getType()); + } + return null; + } + + @VisibleForTesting + public static TableRow tableRowFromMessage(Message message) { + TableRow tableRow = new TableRow(); + for (Map.Entry<FieldDescriptor, Object> field : message.getAllFields().entrySet()) { + FieldDescriptor fieldDescriptor = field.getKey(); + Object fieldValue = field.getValue(); + tableRow.putIfAbsent( + fieldDescriptor.getName(), jsonValueFromMessageValue(fieldDescriptor, fieldValue, true)); + } + return tableRow; + } + + public static Object jsonValueFromMessageValue( + FieldDescriptor fieldDescriptor, Object fieldValue, boolean expandRepeated) { + if (expandRepeated && fieldDescriptor.isRepeated()) { + List<Object> valueList = (List<Object>) fieldValue; + return valueList.stream() + .map(v -> jsonValueFromMessageValue(fieldDescriptor, v, false)) + .collect(toList()); + } + + switch (fieldDescriptor.getType()) { + case GROUP: + case MESSAGE: + return tableRowFromMessage((Message) fieldValue); + case BYTES: + return BaseEncoding.base64().encode(((ByteString) fieldValue).toByteArray()); + case ENUM: + throw new RuntimeException("Enumerations not supported"); + default: + return fieldValue; + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java new file mode 100644 index 0000000..32c0b7a --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java @@ -0,0 +1,391 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static org.junit.Assert.assertEquals; + +import com.google.protobuf.ByteString; +import com.google.protobuf.DescriptorProtos.DescriptorProto; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Label; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DynamicMessage; +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Functions; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +@SuppressWarnings({ + "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402) +}) +/** Unit tests form {@link BeamRowToStorageApiProto}. */ +public class BeamRowToStorageApiProtoTest { + private static final EnumerationType TEST_ENUM = + EnumerationType.create("ONE", "TWO", "RED", "BLUE"); + private static final Schema BASE_SCHEMA = + Schema.builder() + .addField("byteValue", FieldType.BYTE.withNullable(true)) + .addField("int16Value", FieldType.INT16) + .addField("int32Value", FieldType.INT32.withNullable(true)) + .addField("int64Value", FieldType.INT64.withNullable(true)) + .addField("decimalValue", FieldType.DECIMAL.withNullable(true)) + .addField("floatValue", FieldType.FLOAT.withNullable(true)) + .addField("doubleValue", FieldType.DOUBLE.withNullable(true)) + .addField("stringValue", FieldType.STRING.withNullable(true)) + .addField("datetimeValue", FieldType.DATETIME.withNullable(true)) + .addField("booleanValue", FieldType.BOOLEAN.withNullable(true)) + .addField("bytesValue", FieldType.BYTES.withNullable(true)) + .addField("arrayValue", FieldType.array(FieldType.STRING)) + .addField("iterableValue", FieldType.array(FieldType.STRING)) + .addField("sqlDateValue", FieldType.logicalType(SqlTypes.DATE).withNullable(true)) + .addField("sqlTimeValue", FieldType.logicalType(SqlTypes.TIME).withNullable(true)) + .addField("sqlDatetimeValue", FieldType.logicalType(SqlTypes.DATETIME).withNullable(true)) + .addField( + "sqlTimestampValue", FieldType.logicalType(SqlTypes.TIMESTAMP).withNullable(true)) + .addField("enumValue", FieldType.logicalType(TEST_ENUM).withNullable(true)) + .build(); + + private static final DescriptorProto BASE_SCHEMA_PROTO = + DescriptorProto.newBuilder() + .addField( + FieldDescriptorProto.newBuilder() + .setName("bytevalue") + .setNumber(1) + .setType(Type.TYPE_INT32) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("int16value") + .setNumber(2) + .setType(Type.TYPE_INT32) + .setLabel(Label.LABEL_REQUIRED) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("int32value") + .setNumber(3) + .setType(Type.TYPE_INT32) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("int64value") + .setNumber(4) + .setType(Type.TYPE_INT64) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("decimalvalue") + .setNumber(5) + .setType(Type.TYPE_BYTES) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("floatvalue") + .setNumber(6) + .setType(Type.TYPE_FLOAT) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("doublevalue") + .setNumber(7) + .setType(Type.TYPE_DOUBLE) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("stringvalue") + .setNumber(8) + .setType(Type.TYPE_STRING) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("datetimevalue") + .setNumber(9) + .setType(Type.TYPE_INT64) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("booleanvalue") + .setNumber(10) + .setType(Type.TYPE_BOOL) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("bytesvalue") + .setNumber(11) + .setType(Type.TYPE_BYTES) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("arrayvalue") + .setNumber(12) + .setType(Type.TYPE_STRING) + .setLabel(Label.LABEL_REPEATED) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("iterablevalue") + .setNumber(13) + .setType(Type.TYPE_STRING) + .setLabel(Label.LABEL_REPEATED) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("sqldatevalue") + .setNumber(14) + .setType(Type.TYPE_INT32) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("sqltimevalue") + .setNumber(15) + .setType(Type.TYPE_INT64) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("sqldatetimevalue") + .setNumber(16) + .setType(Type.TYPE_INT64) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("sqltimestampvalue") + .setNumber(17) + .setType(Type.TYPE_INT64) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("enumvalue") + .setNumber(18) + .setType(Type.TYPE_STRING) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .build(); + + private static final byte[] BYTES = "BYTE BYTE BYTE".getBytes(StandardCharsets.UTF_8); + private static final Row BASE_ROW = + Row.withSchema(BASE_SCHEMA) + .withFieldValue("byteValue", (byte) 1) + .withFieldValue("int16Value", (short) 2) + .withFieldValue("int32Value", (int) 3) + .withFieldValue("int64Value", (long) 4) + .withFieldValue("decimalValue", BigDecimal.valueOf(5)) + .withFieldValue("floatValue", (float) 3.14) + .withFieldValue("doubleValue", (double) 2.68) + .withFieldValue("stringValue", "I am a string. Hear me roar.") + .withFieldValue("datetimeValue", Instant.now()) + .withFieldValue("booleanValue", true) + .withFieldValue("bytesValue", BYTES) + .withFieldValue("arrayValue", ImmutableList.of("one", "two", "red", "blue")) + .withFieldValue("iterableValue", ImmutableList.of("blue", "red", "two", "one")) + .withFieldValue("sqlDateValue", LocalDate.now()) + .withFieldValue("sqlTimeValue", LocalTime.now()) + .withFieldValue("sqlDatetimeValue", LocalDateTime.now()) + .withFieldValue("sqlTimestampValue", java.time.Instant.now()) + .withFieldValue("enumValue", TEST_ENUM.valueOf("RED")) + .build(); + private static final Map<String, Object> BASE_PROTO_EXPECTED_FIELDS = + ImmutableMap.<String, Object>builder() + .put("bytevalue", (int) 1) + .put("int16value", (int) 2) + .put("int32value", (int) 3) + .put("int64value", (long) 4) + .put( + "decimalvalue", + BeamRowToStorageApiProto.serializeBigDecimalToNumeric(BigDecimal.valueOf(5))) + .put("floatvalue", (float) 3.14) + .put("doublevalue", (double) 2.68) + .put("stringvalue", "I am a string. Hear me roar.") + .put("datetimevalue", BASE_ROW.getDateTime("datetimeValue").getMillis() * 1000) + .put("booleanvalue", true) + .put("bytesvalue", ByteString.copyFrom(BYTES)) + .put("arrayvalue", ImmutableList.of("one", "two", "red", "blue")) + .put("iterablevalue", ImmutableList.of("blue", "red", "two", "one")) + .put( + "sqldatevalue", + (int) BASE_ROW.getLogicalTypeValue("sqlDateValue", LocalDate.class).toEpochDay()) + .put( + "sqltimevalue", + CivilTimeEncoder.encodePacked64TimeMicros( + BASE_ROW.getLogicalTypeValue("sqlTimeValue", LocalTime.class))) + .put( + "sqldatetimevalue", + CivilTimeEncoder.encodePacked64DatetimeSeconds( + BASE_ROW.getLogicalTypeValue("sqlDatetimeValue", LocalDateTime.class))) + .put( + "sqltimestampvalue", + BASE_ROW + .getLogicalTypeValue("sqlTimestampValue", java.time.Instant.class) + .toEpochMilli() + * 1000) + .put("enumvalue", "RED") + .build(); + + private static final Schema NESTED_SCHEMA = + Schema.builder() + .addField("nested", FieldType.row(BASE_SCHEMA).withNullable(true)) + .addField("nestedArray", FieldType.array(FieldType.row(BASE_SCHEMA))) + .addField("nestedIterable", FieldType.iterable(FieldType.row(BASE_SCHEMA))) + .build(); + private static final Row NESTED_ROW = + Row.withSchema(NESTED_SCHEMA) + .withFieldValue("nested", BASE_ROW) + .withFieldValue("nestedArray", ImmutableList.of(BASE_ROW, BASE_ROW)) + .withFieldValue("nestedIterable", ImmutableList.of(BASE_ROW, BASE_ROW)) + .build(); + + @Test + public void testDescriptorFromSchema() { + DescriptorProto descriptor = + BeamRowToStorageApiProto.descriptorSchemaFromBeamSchema(BASE_SCHEMA); + Map<String, Type> types = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + Map<String, Type> expectedTypes = + BASE_SCHEMA_PROTO.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + assertEquals(expectedTypes, types); + + Map<String, String> nameMapping = + BASE_SCHEMA.getFields().stream() + .collect(Collectors.toMap(f -> f.getName().toLowerCase(), Field::getName)); + descriptor + .getFieldList() + .forEach( + p -> { + FieldType schemaFieldType = + BASE_SCHEMA.getField(nameMapping.get(p.getName())).getType(); + Label label = + schemaFieldType.getTypeName().isCollectionType() + ? Label.LABEL_REPEATED + : schemaFieldType.getNullable() ? Label.LABEL_OPTIONAL : Label.LABEL_REQUIRED; + assertEquals(label, p.getLabel()); + }); + } + + @Test + public void testNestedFromSchema() { + DescriptorProto descriptor = + BeamRowToStorageApiProto.descriptorSchemaFromBeamSchema(NESTED_SCHEMA); + Map<String, Type> expectedBaseTypes = + BASE_SCHEMA_PROTO.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + + Map<String, Type> types = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + Map<String, String> typeNames = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getTypeName)); + Map<String, Label> typeLabels = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getLabel)); + + assertEquals(3, types.size()); + + Map<String, DescriptorProto> nestedTypes = + descriptor.getNestedTypeList().stream() + .collect(Collectors.toMap(DescriptorProto::getName, Functions.identity())); + assertEquals(3, nestedTypes.size()); + assertEquals(Type.TYPE_MESSAGE, types.get("nested")); + assertEquals(Label.LABEL_OPTIONAL, typeLabels.get("nested")); + String nestedTypeName1 = typeNames.get("nested"); + Map<String, Type> nestedTypes1 = + nestedTypes.get(nestedTypeName1).getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + assertEquals(expectedBaseTypes, nestedTypes1); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedarray")); + assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedarray")); + String nestedTypeName2 = typeNames.get("nestedarray"); + Map<String, Type> nestedTypes2 = + nestedTypes.get(nestedTypeName2).getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + assertEquals(expectedBaseTypes, nestedTypes2); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestediterable")); + assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestediterable")); + String nestedTypeName3 = typeNames.get("nestediterable"); + Map<String, Type> nestedTypes3 = + nestedTypes.get(nestedTypeName3).getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + assertEquals(expectedBaseTypes, nestedTypes3); + } + + private void assertBaseRecord(DynamicMessage msg) { + Map<String, Object> recordFields = + msg.getAllFields().entrySet().stream() + .collect( + Collectors.toMap(entry -> entry.getKey().getName(), entry -> entry.getValue())); + assertEquals(BASE_PROTO_EXPECTED_FIELDS, recordFields); + } + + @Test + public void testMessageFromTableRow() throws Exception { + Descriptor descriptor = BeamRowToStorageApiProto.getDescriptorFromSchema(NESTED_SCHEMA); + DynamicMessage msg = BeamRowToStorageApiProto.messageFromBeamRow(descriptor, NESTED_ROW); + assertEquals(3, msg.getAllFields().size()); + + Map<String, FieldDescriptor> fieldDescriptors = + descriptor.getFields().stream() + .collect(Collectors.toMap(FieldDescriptor::getName, Functions.identity())); + DynamicMessage nestedMsg = (DynamicMessage) msg.getField(fieldDescriptors.get("nested")); + assertBaseRecord(nestedMsg); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java index d0d4c31..1c856cc 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java @@ -31,14 +31,6 @@ import static org.junit.Assert.assertThrows; import com.google.api.services.bigquery.model.TableFieldSchema; import com.google.api.services.bigquery.model.TableRow; import com.google.api.services.bigquery.model.TableSchema; -import com.google.protobuf.ByteString; -import com.google.protobuf.DescriptorProtos.DescriptorProto; -import com.google.protobuf.DescriptorProtos.FieldDescriptorProto; -import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Label; -import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type; -import com.google.protobuf.Descriptors.Descriptor; -import com.google.protobuf.Descriptors.FieldDescriptor; -import com.google.protobuf.DynamicMessage; import java.math.BigDecimal; import java.math.RoundingMode; import java.nio.ByteBuffer; @@ -46,13 +38,10 @@ import java.nio.charset.StandardCharsets; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.LocalTime; -import java.util.AbstractMap; import java.util.Arrays; import java.util.Base64; import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; import org.apache.avro.Conversions; import org.apache.avro.LogicalTypes; import org.apache.avro.generic.GenericData; @@ -64,10 +53,7 @@ import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.schemas.utils.AvroUtils; import org.apache.beam.sdk.values.Row; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Functions; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.BaseEncoding; import org.joda.time.DateTime; import org.joda.time.Instant; import org.joda.time.chrono.ISOChronology; @@ -787,250 +773,4 @@ public class BigQueryUtilsTest { record, AVRO_ARRAY_ARRAY_TYPE, BigQueryUtils.ConversionOptions.builder().build()); assertEquals(expected, beamRow); } - - private static final TableSchema BASE_TABLE_SCHEMA = - new TableSchema() - .setFields( - ImmutableList.<TableFieldSchema>builder() - .add(new TableFieldSchema().setType("STRING").setName("stringValue")) - .add(new TableFieldSchema().setType("BYTES").setName("bytesValue")) - .add(new TableFieldSchema().setType("INT64").setName("int64Value")) - .add(new TableFieldSchema().setType("INTEGER").setName("intValue")) - .add(new TableFieldSchema().setType("FLOAT64").setName("float64Value")) - .add(new TableFieldSchema().setType("FLOAT").setName("floatValue")) - .add(new TableFieldSchema().setType("BOOL").setName("boolValue")) - .add(new TableFieldSchema().setType("BOOLEAN").setName("booleanValue")) - .add(new TableFieldSchema().setType("TIMESTAMP").setName("timestampValue")) - .add(new TableFieldSchema().setType("TIME").setName("timeValue")) - .add(new TableFieldSchema().setType("DATETIME").setName("datetimeValue")) - .add(new TableFieldSchema().setType("DATE").setName("dateValue")) - .build()); - - private static final DescriptorProto BASE_TABLE_SCHEMA_PROTO = - DescriptorProto.newBuilder() - .addField( - FieldDescriptorProto.newBuilder() - .setName("stringValue") - .setNumber(1) - .setType(Type.TYPE_STRING) - .setLabel(Label.LABEL_OPTIONAL) - .build()) - .addField( - FieldDescriptorProto.newBuilder() - .setName("bytesValue") - .setNumber(2) - .setType(Type.TYPE_BYTES) - .setLabel(Label.LABEL_OPTIONAL) - .build()) - .addField( - FieldDescriptorProto.newBuilder() - .setName("int64Value") - .setNumber(3) - .setType(Type.TYPE_INT64) - .setLabel(Label.LABEL_OPTIONAL) - .build()) - .addField( - FieldDescriptorProto.newBuilder() - .setName("intValue") - .setNumber(4) - .setType(Type.TYPE_INT64) - .setLabel(Label.LABEL_OPTIONAL) - .build()) - .addField( - FieldDescriptorProto.newBuilder() - .setName("float64Value") - .setNumber(5) - .setType(Type.TYPE_FLOAT) - .setLabel(Label.LABEL_OPTIONAL) - .build()) - .addField( - FieldDescriptorProto.newBuilder() - .setName("floatValue") - .setNumber(6) - .setType(Type.TYPE_FLOAT) - .setLabel(Label.LABEL_OPTIONAL) - .build()) - .addField( - FieldDescriptorProto.newBuilder() - .setName("boolValue") - .setNumber(7) - .setType(Type.TYPE_BOOL) - .setLabel(Label.LABEL_OPTIONAL) - .build()) - .addField( - FieldDescriptorProto.newBuilder() - .setName("booleanValue") - .setNumber(8) - .setType(Type.TYPE_BOOL) - .setLabel(Label.LABEL_OPTIONAL) - .build()) - .addField( - FieldDescriptorProto.newBuilder() - .setName("timestampValue") - .setNumber(9) - .setType(Type.TYPE_INT64) - .setLabel(Label.LABEL_OPTIONAL) - .build()) - .addField( - FieldDescriptorProto.newBuilder() - .setName("timeValue") - .setNumber(10) - .setType(Type.TYPE_INT64) - .setLabel(Label.LABEL_OPTIONAL) - .build()) - .addField( - FieldDescriptorProto.newBuilder() - .setName("datetimeValue") - .setNumber(11) - .setType(Type.TYPE_INT64) - .setLabel(Label.LABEL_OPTIONAL) - .build()) - .addField( - FieldDescriptorProto.newBuilder() - .setName("dateValue") - .setNumber(12) - .setType(Type.TYPE_INT32) - .setLabel(Label.LABEL_OPTIONAL) - .build()) - .build(); - - private static final TableSchema NESTED_TABLE_SCHEMA = - new TableSchema() - .setFields( - ImmutableList.<TableFieldSchema>builder() - .add( - new TableFieldSchema() - .setType("STRUCT") - .setName("nestedValue1") - .setFields(BASE_TABLE_SCHEMA.getFields())) - .add( - new TableFieldSchema() - .setType("RECORD") - .setName("nestedValue2") - .setFields(BASE_TABLE_SCHEMA.getFields())) - .build()); - - // For now, test that no exceptions are thrown. - @Test - public void testDescriptorFromTableSchema() { - DescriptorProto descriptor = BigQueryUtils.descriptorSchemaFromTableSchema(BASE_TABLE_SCHEMA); - Map<String, Type> types = - descriptor.getFieldList().stream() - .collect( - Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); - Map<String, Type> expectedTypes = - BASE_TABLE_SCHEMA_PROTO.getFieldList().stream() - .collect( - Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); - assertEquals(expectedTypes, types); - } - - @Test - public void testNestedFromTableSchema() { - DescriptorProto descriptor = BigQueryUtils.descriptorSchemaFromTableSchema(NESTED_TABLE_SCHEMA); - Map<String, Type> expectedBaseTypes = - BASE_TABLE_SCHEMA_PROTO.getFieldList().stream() - .collect( - Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); - - Map<String, Type> types = - descriptor.getFieldList().stream() - .collect( - Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); - Map<String, String> typeNames = - descriptor.getFieldList().stream() - .collect( - Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getTypeName)); - assertEquals(2, types.size()); - - Map<String, DescriptorProto> nestedTypes = - descriptor.getNestedTypeList().stream() - .collect(Collectors.toMap(DescriptorProto::getName, Functions.identity())); - assertEquals(2, nestedTypes.size()); - assertEquals(Type.TYPE_MESSAGE, types.get("nestedValue1")); - String nestedTypeName1 = typeNames.get("nestedValue1"); - Map<String, Type> nestedTypes1 = - nestedTypes.get(nestedTypeName1).getFieldList().stream() - .collect( - Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); - assertEquals(expectedBaseTypes, nestedTypes1); - - assertEquals(Type.TYPE_MESSAGE, types.get("nestedValue2")); - String nestedTypeName2 = typeNames.get("nestedValue2"); - Map<String, Type> nestedTypes2 = - nestedTypes.get(nestedTypeName2).getFieldList().stream() - .collect( - Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); - assertEquals(expectedBaseTypes, nestedTypes2); - } - - @Test - public void testRepeatedDescriptorFromTableSchema() { - BigQueryUtils.descriptorSchemaFromTableSchema(BASE_TABLE_SCHEMA); - } - - private TableRow getBaseRecord() { - return new TableRow() - .set("stringValue", "string") - .set("bytesValue", BaseEncoding.base64().encode("string".getBytes(StandardCharsets.UTF_8))) - .set("int64Value", 42L) - .set("intValue", 43L) - .set("float64Value", (float) 2.8168) - .set("floatValue", (float) 2.817) - .set("boolValue", true) - .set("booleanValue", true) - .set("timestampValue", 1L) - .set("timeValue", 2L) - .set("datetimeValue", 3L) - .set("dateValue", 4); - } - - @Test - public void testMessageFromTableRow() throws Exception { - TableRow baseRecord = getBaseRecord(); - Map<String, Object> baseRecordFields = ImmutableMap.copyOf(baseRecord); - - TableRow tableRow = - new TableRow().set("nestedValue1", baseRecord).set("nestedValue2", baseRecord); - Descriptor descriptor = BigQueryUtils.getDescriptorFromTableSchema(NESTED_TABLE_SCHEMA); - DynamicMessage msg = BigQueryUtils.messageFromTableRow(descriptor, tableRow); - assertEquals(2, msg.getAllFields().size()); - - Map<String, FieldDescriptor> fieldDescriptors = - descriptor.getFields().stream() - .collect(Collectors.toMap(FieldDescriptor::getName, Functions.identity())); - DynamicMessage nestedMsg1 = (DynamicMessage) msg.getField(fieldDescriptors.get("nestedValue1")); - Map<String, Object> nestedMsg1Fields = - nestedMsg1.getAllFields().entrySet().stream() - .map( - entry -> { - if (entry.getKey().getType() == FieldDescriptor.Type.BYTES) { - ByteString byteString = (ByteString) entry.getValue(); - return new AbstractMap.SimpleEntry<>( - entry.getKey(), BaseEncoding.base64().encode(byteString.toByteArray())); - } else { - return entry; - } - }) - .collect( - Collectors.toMap(entry -> entry.getKey().getName(), entry -> entry.getValue())); - assertEquals(baseRecordFields, nestedMsg1Fields); - - DynamicMessage nestedMsg2 = (DynamicMessage) msg.getField(fieldDescriptors.get("nestedValue2")); - Map<String, Object> nestedMsg2Fields = - nestedMsg2.getAllFields().entrySet().stream() - .map( - entry -> { - if (entry.getKey().getType() == FieldDescriptor.Type.BYTES) { - ByteString byteString = (ByteString) entry.getValue(); - return new AbstractMap.SimpleEntry<>( - entry.getKey(), BaseEncoding.base64().encode(byteString.toByteArray())); - } else { - return entry; - } - }) - .collect( - Collectors.toMap(entry -> entry.getKey().getName(), entry -> entry.getValue())); - assertEquals(baseRecordFields, nestedMsg2Fields); - } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoTest.java new file mode 100644 index 0000000..0acc41f --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoTest.java @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static org.junit.Assert.assertEquals; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.protobuf.ByteString; +import com.google.protobuf.DescriptorProtos.DescriptorProto; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Label; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DynamicMessage; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Functions; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.BaseEncoding; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +@SuppressWarnings({ + "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402) +}) +/** Unit tests for {@link org.apache.beam.sdk.io.gcp.bigquery.TableRowToStorageApiProto}. */ +public class TableRowToStorageApiProtoTest { + private static final TableSchema BASE_TABLE_SCHEMA = + new TableSchema() + .setFields( + ImmutableList.<TableFieldSchema>builder() + .add(new TableFieldSchema().setType("STRING").setName("stringValue")) + .add(new TableFieldSchema().setType("BYTES").setName("bytesValue")) + .add(new TableFieldSchema().setType("INT64").setName("int64Value")) + .add(new TableFieldSchema().setType("INTEGER").setName("intValue")) + .add(new TableFieldSchema().setType("FLOAT64").setName("float64Value")) + .add(new TableFieldSchema().setType("FLOAT").setName("floatValue")) + .add(new TableFieldSchema().setType("BOOL").setName("boolValue")) + .add(new TableFieldSchema().setType("BOOLEAN").setName("booleanValue")) + .add(new TableFieldSchema().setType("TIMESTAMP").setName("timestampValue")) + .add(new TableFieldSchema().setType("TIME").setName("timeValue")) + .add(new TableFieldSchema().setType("DATETIME").setName("datetimeValue")) + .add(new TableFieldSchema().setType("DATE").setName("dateValue")) + .add(new TableFieldSchema().setType("NUMERIC").setName("numericValue")) + .add( + new TableFieldSchema() + .setType("STRING") + .setMode("REPEATED") + .setName("arrayValue")) + .build()); + + private static final DescriptorProto BASE_TABLE_SCHEMA_PROTO = + DescriptorProto.newBuilder() + .addField( + FieldDescriptorProto.newBuilder() + .setName("stringvalue") + .setNumber(1) + .setType(Type.TYPE_STRING) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("bytesvalue") + .setNumber(2) + .setType(Type.TYPE_BYTES) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("int64value") + .setNumber(3) + .setType(Type.TYPE_INT64) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("intvalue") + .setNumber(4) + .setType(Type.TYPE_INT64) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("float64value") + .setNumber(5) + .setType(Type.TYPE_DOUBLE) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("floatvalue") + .setNumber(6) + .setType(Type.TYPE_DOUBLE) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("boolvalue") + .setNumber(7) + .setType(Type.TYPE_BOOL) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("booleanvalue") + .setNumber(8) + .setType(Type.TYPE_BOOL) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("timestampvalue") + .setNumber(9) + .setType(Type.TYPE_STRING) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("timevalue") + .setNumber(10) + .setType(Type.TYPE_STRING) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("datetimevalue") + .setNumber(11) + .setType(Type.TYPE_STRING) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("datevalue") + .setNumber(12) + .setType(Type.TYPE_STRING) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("numericvalue") + .setNumber(13) + .setType(Type.TYPE_STRING) + .setLabel(Label.LABEL_OPTIONAL) + .build()) + .addField( + FieldDescriptorProto.newBuilder() + .setName("arrayvalue") + .setNumber(14) + .setType(Type.TYPE_STRING) + .setLabel(Label.LABEL_REPEATED) + .build()) + .build(); + + private static final TableSchema NESTED_TABLE_SCHEMA = + new TableSchema() + .setFields( + ImmutableList.<TableFieldSchema>builder() + .add( + new TableFieldSchema() + .setType("STRUCT") + .setName("nestedValue1") + .setFields(BASE_TABLE_SCHEMA.getFields())) + .add( + new TableFieldSchema() + .setType("RECORD") + .setName("nestedValue2") + .setFields(BASE_TABLE_SCHEMA.getFields())) + .build()); + + // For now, test that no exceptions are thrown. + @Test + public void testDescriptorFromTableSchema() { + DescriptorProto descriptor = + TableRowToStorageApiProto.descriptorSchemaFromTableSchema(BASE_TABLE_SCHEMA); + Map<String, Type> types = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + Map<String, Type> expectedTypes = + BASE_TABLE_SCHEMA_PROTO.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + assertEquals(expectedTypes, types); + } + + @Test + public void testNestedFromTableSchema() { + DescriptorProto descriptor = + TableRowToStorageApiProto.descriptorSchemaFromTableSchema(NESTED_TABLE_SCHEMA); + Map<String, Type> expectedBaseTypes = + BASE_TABLE_SCHEMA_PROTO.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + + Map<String, Type> types = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + Map<String, String> typeNames = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getTypeName)); + assertEquals(2, types.size()); + + Map<String, DescriptorProto> nestedTypes = + descriptor.getNestedTypeList().stream() + .collect(Collectors.toMap(DescriptorProto::getName, Functions.identity())); + assertEquals(2, nestedTypes.size()); + assertEquals(Type.TYPE_MESSAGE, types.get("nestedvalue1")); + String nestedTypeName1 = typeNames.get("nestedvalue1"); + Map<String, Type> nestedTypes1 = + nestedTypes.get(nestedTypeName1).getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + assertEquals(expectedBaseTypes, nestedTypes1); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedvalue2")); + String nestedTypeName2 = typeNames.get("nestedvalue2"); + Map<String, Type> nestedTypes2 = + nestedTypes.get(nestedTypeName2).getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + assertEquals(expectedBaseTypes, nestedTypes2); + } + + @Test + public void testRepeatedDescriptorFromTableSchema() { + TableRowToStorageApiProto.descriptorSchemaFromTableSchema(BASE_TABLE_SCHEMA); + } + + private static final TableRow BASE_TABLE_ROW = + new TableRow() + .set("stringValue", "string") + .set( + "bytesValue", BaseEncoding.base64().encode("string".getBytes(StandardCharsets.UTF_8))) + .set("int64Value", "42") + .set("intValue", "43") + .set("float64Value", "2.8168") + .set("floatValue", "2.817") + .set("boolValue", "true") + .set("booleanValue", "true") + .set("timestampValue", "43") + .set("timeValue", "00:52:07[.123]|[.123456] UTC") + .set("datetimeValue", "2019-08-16 00:52:07[.123]|[.123456] UTC") + .set("dateValue", "2019-08-16") + .set("numericValue", "23.4") + .set("arrayValue", ImmutableList.of("hello", "goodbye")); + + private static final Map<String, Object> BASE_ROW_EXPECTED_PROTO_VALUES = + ImmutableMap.<String, Object>builder() + .put("stringvalue", "string") + .put("bytesvalue", ByteString.copyFrom("string".getBytes(StandardCharsets.UTF_8))) + .put("int64value", (long) 42) + .put("intvalue", (long) 43) + .put("float64value", (double) 2.8168) + .put("floatvalue", (double) 2.817) + .put("boolvalue", true) + .put("booleanvalue", true) + .put("timestampvalue", "43") + .put("timevalue", "00:52:07[.123]|[.123456] UTC") + .put("datetimevalue", "2019-08-16 00:52:07[.123]|[.123456] UTC") + .put("datevalue", "2019-08-16") + .put("numericvalue", "23.4") + .put("arrayvalue", ImmutableList.of("hello", "goodbye")) + .build(); + + private void assertBaseRecord(DynamicMessage msg) { + Map<String, Object> recordFields = + msg.getAllFields().entrySet().stream() + .collect( + Collectors.toMap(entry -> entry.getKey().getName(), entry -> entry.getValue())); + assertEquals(BASE_ROW_EXPECTED_PROTO_VALUES, recordFields); + } + + @Test + public void testMessageFromTableRow() throws Exception { + TableRow tableRow = + new TableRow().set("nestedValue1", BASE_TABLE_ROW).set("nestedValue2", BASE_TABLE_ROW); + Descriptor descriptor = + TableRowToStorageApiProto.getDescriptorFromTableSchema(NESTED_TABLE_SCHEMA); + DynamicMessage msg = TableRowToStorageApiProto.messageFromTableRow(descriptor, tableRow); + assertEquals(2, msg.getAllFields().size()); + + Map<String, FieldDescriptor> fieldDescriptors = + descriptor.getFields().stream() + .collect(Collectors.toMap(FieldDescriptor::getName, Functions.identity())); + assertBaseRecord((DynamicMessage) msg.getField(fieldDescriptors.get("nestedvalue1"))); + assertBaseRecord((DynamicMessage) msg.getField(fieldDescriptors.get("nestedvalue2"))); + } +}