This is an automated email from the ASF dual-hosted git repository. exceptionfactory pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/nifi.git
The following commit(s) were added to refs/heads/main by this push: new 9141f64ef5 NIFI-13356 Fixed ProtobufReader handling of repeated fields 9141f64ef5 is described below commit 9141f64ef523bc9a150a8b52f453817dad46a637 Author: Mark Bathori <mbath...@apache.org> AuthorDate: Tue Jun 4 12:45:41 2024 +0200 NIFI-13356 Fixed ProtobufReader handling of repeated fields This closes #8922 Signed-off-by: David Handermann <exceptionfact...@apache.org> --- .../protobuf/converter/ProtobufDataConverter.java | 82 +++++++++++++++++++-- .../services/protobuf/converter/ValueReader.java} | 41 ++--------- .../nifi/services/protobuf/ProtoTestUtil.java | 69 ++++++++++++++++- .../converter/TestProtobufDataConverter.java | 46 ++++++++++-- .../protobuf/schema/TestProtoSchemaParser.java | 36 ++++++++- .../src/test/resources/test_proto3.desc | Bin 1022 -> 984 bytes .../src/test/resources/test_proto3.proto | 3 +- .../src/test/resources/test_repeated_proto3.desc | Bin 0 -> 755 bytes ...est_proto3.proto => test_repeated_proto3.proto} | 45 +++++------ 9 files changed, 242 insertions(+), 80 deletions(-) diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ProtobufDataConverter.java b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ProtobufDataConverter.java index df5c491fa9..81e0226d4a 100644 --- a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ProtobufDataConverter.java +++ b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ProtobufDataConverter.java @@ -30,6 +30,7 @@ import org.apache.nifi.serialization.record.DataType; import org.apache.nifi.serialization.record.MapRecord; import org.apache.nifi.serialization.record.RecordField; import org.apache.nifi.serialization.record.RecordSchema; +import org.apache.nifi.serialization.record.type.ArrayDataType; import org.apache.nifi.serialization.record.type.RecordDataType; import org.apache.nifi.serialization.record.util.DataTypeUtils; import org.apache.nifi.services.protobuf.FieldType; @@ -38,6 +39,7 @@ import org.apache.nifi.services.protobuf.schema.ProtoSchemaParser; import java.io.IOException; import java.io.InputStream; import java.math.BigInteger; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -47,6 +49,8 @@ import java.util.function.Function; import static com.google.protobuf.CodedInputStream.decodeZigZag32; import static com.google.protobuf.TextFormat.unsignedToString; +import static org.apache.nifi.services.protobuf.FieldType.STRING; +import static org.apache.nifi.services.protobuf.FieldType.BYTES; /** * The class is responsible for creating Record by mapping the provided proto schema fields with the list of Unknown fields parsed from encoded proto data. @@ -154,7 +158,11 @@ public class ProtobufDataConverter { private Optional<Object> convertFieldValues(ProtoField protoField, UnknownFieldSet.Field unknownField) throws InvalidProtocolBufferException { if (!unknownField.getLengthDelimitedList().isEmpty()) { - return Optional.of(convertLengthDelimitedFields(protoField, unknownField.getLengthDelimitedList())); + if (protoField.isRepeatable() && !isLengthDelimitedType(protoField)) { + return Optional.of(convertRepeatedFields(protoField, unknownField.getLengthDelimitedList())); + } else { + return Optional.of(convertLengthDelimitedFields(protoField, unknownField.getLengthDelimitedList())); + } } if (!unknownField.getFixed32List().isEmpty()) { return Optional.of(convertFixed32Fields(protoField, unknownField.getFixed32List())); @@ -169,6 +177,34 @@ public class ProtobufDataConverter { return Optional.empty(); } + private Object convertRepeatedFields(ProtoField protoField, List<ByteString> fieldValues) { + final CodedInputStream inputStream = fieldValues.getFirst().newCodedInput(); + final ProtoType protoType = protoField.getProtoType(); + if (protoType.isScalar()) { + final ValueReader<CodedInputStream, Object> valueReader = switch (FieldType.findValue(protoType.getSimpleName())) { + case BOOL -> CodedInputStream::readBool; + case INT32 -> CodedInputStream::readInt32; + case UINT32 -> value -> Integer.toUnsignedLong(value.readUInt32()); + case SINT32 -> CodedInputStream::readSInt32; + case INT64 -> CodedInputStream::readInt64; + case UINT64 -> value -> new BigInteger(unsignedToString(value.readUInt64())); + case SINT64 -> CodedInputStream::readSInt64; + case FIXED32 -> value -> Integer.toUnsignedLong(value.readFixed32()); + case SFIXED32 -> CodedInputStream::readSFixed32; + case FIXED64 -> value -> new BigInteger(unsignedToString(value.readFixed64())); + case SFIXED64 -> CodedInputStream::readSFixed64; + case FLOAT -> CodedInputStream::readFloat; + case DOUBLE -> CodedInputStream::readDouble; + default -> throw new IllegalStateException(String.format("Unexpected type [%s] was received for field [%s]", + protoType.getSimpleName(), protoField.getFieldName())); + }; + return resolveFieldValue(protoField, processRepeatedValues(inputStream, valueReader), value -> value); + } else { + List<Integer> values = processRepeatedValues(inputStream, CodedInputStream::readEnum); + return resolveFieldValue(protoField, values, value -> convertEnum(value, protoType)); + } + } + /** * Converts a Length-Delimited field value into it's suitable data type. * @@ -197,6 +233,10 @@ public class ProtobufDataConverter { valueConverter = value -> { try { Optional<DataType> recordDataType = rootRecordSchema.getDataType(protoField.getFieldName()); + if (protoField.isRepeatable()) { + final ArrayDataType arrayDataType = (ArrayDataType) recordDataType.get(); + recordDataType = Optional.ofNullable(arrayDataType.getElementType()); + } RecordSchema recordSchema = recordDataType.map(dataType -> ((RecordDataType) dataType).getChildSchema()).orElse(generateRecordSchema(messageType.getType().toString())); return createRecord(messageType, value, recordSchema); @@ -220,7 +260,7 @@ public class ProtobufDataConverter { final String typeName = protoField.getProtoType().getSimpleName(); final Function<Integer, Object> valueConverter = switch (FieldType.findValue(typeName)) { - case FIXED32 -> value -> Long.parseLong(unsignedToString(value)); + case FIXED32 -> Integer::toUnsignedLong; case SFIXED32 -> value -> value; case FLOAT -> Float::intBitsToFloat; default -> @@ -276,11 +316,7 @@ public class ProtobufDataConverter { " [%s] is not Varint field type", protoField.getFieldName(), protoType.getSimpleName())); }; } else { - valueConverter = value -> { - final EnumType enumType = (EnumType) schema.getType(protoType); - Objects.requireNonNull(enumType, String.format("Enum with name [%s] not found in the provided proto files", protoType)); - return enumType.constant(Integer.parseInt(value.toString())).getName(); - }; + valueConverter = value -> convertEnum(value.intValue(), protoType); } return resolveFieldValue(protoField, values, valueConverter); @@ -297,7 +333,7 @@ public class ProtobufDataConverter { } if (!protoField.isRepeatable()) { - return resultValues.get(0); + return resultValues.getFirst(); } else { return resultValues.toArray(); } @@ -327,6 +363,12 @@ public class ProtobufDataConverter { return mapResult; } + private String convertEnum(Integer value, ProtoType protoType) { + final EnumType enumType = (EnumType) schema.getType(protoType); + Objects.requireNonNull(enumType, String.format("Enum with name [%s] not found in the provided proto files", protoType)); + return enumType.constant(value).getName(); + } + /** * Process a 'google.protobuf.Any' typed field. The method gets the schema for the message type provided in the 'type_url' property * and parse the serialized message from the 'value' field. The result record will contain only the parsed message's fields. @@ -368,4 +410,28 @@ public class ProtobufDataConverter { private String getQualifiedTypeName(String typeName) { return typeName.substring(typeName.lastIndexOf('/') + 1); } + + private <T> List<T> processRepeatedValues(CodedInputStream input, ValueReader<CodedInputStream, T> valueReader) { + List<T> result = new ArrayList<>(); + try { + while (input.getBytesUntilLimit() > 0) { + result.add(valueReader.apply(input)); + } + } catch (Exception e) { + throw new IllegalStateException("Unable to parse repeated field", e); + } + return result; + } + + private boolean isLengthDelimitedType(ProtoField protoField) { + boolean lengthDelimitedScalarType = false; + final ProtoType protoType = protoField.getProtoType(); + + if (protoType.isScalar()) { + final FieldType fieldType = FieldType.findValue(protoType.getSimpleName()); + lengthDelimitedScalarType = fieldType.equals(STRING) || fieldType.equals(BYTES); + } + + return lengthDelimitedScalarType || schema.getType(protoType) instanceof MessageType; + } } diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.proto b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ValueReader.java similarity index 50% copy from nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.proto copy to nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ValueReader.java index a6ddec0e61..cff78dea51 100644 --- a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.proto +++ b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/main/java/org/apache/nifi/services/protobuf/converter/ValueReader.java @@ -14,40 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -syntax = "proto3"; +package org.apache.nifi.services.protobuf.converter; -message Proto3Message { - bool booleanField = 1; - string stringField = 2; - int32 int32Field = 3; - uint32 uint32Field = 4; - sint32 sint32Field = 5; - fixed32 fixed32Field = 6; - sfixed32 sfixed32Field = 7; - double doubleField = 8; - float floatField = 9; - bytes bytesField = 10; - int64 int64Field = 11; - uint64 uint64Field = 12; - sint64 sint64Field = 13; - fixed64 fixed64Field = 14; - sfixed64 sfixed64Field = 15; - NestedMessage nestedMessage = 16; -} +import java.io.IOException; -message NestedMessage { - TestEnum testEnum = 20; - repeated string repeatedField = 21; - oneof oneOfField { - string stringOption = 22; - bool booleanOption = 23; - int32 int32Option = 24; - } - map<string, int32> testMap = 25; -} +@FunctionalInterface +interface ValueReader<T, R> { -enum TestEnum { - ENUM_VALUE_1 = 0; - ENUM_VALUE_2 = 1; - ENUM_VALUE_3 = 2; -} \ No newline at end of file + R apply(T t) throws IOException; + +} diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/ProtoTestUtil.java b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/ProtoTestUtil.java index 4a10c0ecfd..c0d273da95 100644 --- a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/ProtoTestUtil.java +++ b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/ProtoTestUtil.java @@ -44,6 +44,12 @@ public class ProtoTestUtil { return schemaLoader.loadSchema(); } + public static Schema loadRepeatedProto3TestSchema() { + final SchemaLoader schemaLoader = new SchemaLoader(FileSystems.getDefault()); + schemaLoader.initRoots(Collections.singletonList(Location.get(BASE_TEST_PATH + "test_repeated_proto3.proto")), Collections.emptyList()); + return schemaLoader.loadSchema(); + } + public static Schema loadProto2TestSchema() { final SchemaLoader schemaLoader = new SchemaLoader(FileSystems.getDefault()); schemaLoader.initRoots(Arrays.asList( @@ -76,13 +82,10 @@ public class ProtoTestUtil { DynamicMessage nestedMessage = DynamicMessage .newBuilder(nestedMessageDescriptor) .setField(nestedMessageDescriptor.findFieldByNumber(20), enumValueDescriptor.findValueByNumber(2)) - .addRepeatedField(nestedMessageDescriptor.findFieldByNumber(21), "Repeated 1") - .addRepeatedField(nestedMessageDescriptor.findFieldByNumber(21), "Repeated 2") - .addRepeatedField(nestedMessageDescriptor.findFieldByNumber(21), "Repeated 3") + .setField(nestedMessageDescriptor.findFieldByNumber(21), Arrays.asList(mapEntry1, mapEntry2)) .setField(nestedMessageDescriptor.findFieldByNumber(22), "One Of Option") .setField(nestedMessageDescriptor.findFieldByNumber(23), true) .setField(nestedMessageDescriptor.findFieldByNumber(24), 3) - .setField(nestedMessageDescriptor.findFieldByNumber(25), Arrays.asList(mapEntry1, mapEntry2)) .build(); DynamicMessage message = DynamicMessage @@ -108,6 +111,64 @@ public class ProtoTestUtil { return message.toByteString().newInput(); } + public static InputStream generateInputDataForRepeatedProto3() throws IOException, Descriptors.DescriptorValidationException { + DescriptorProtos.FileDescriptorSet descriptorSet = DescriptorProtos.FileDescriptorSet.parseFrom(new FileInputStream(BASE_TEST_PATH + "test_repeated_proto3.desc")); + Descriptors.FileDescriptor fileDescriptor = Descriptors.FileDescriptor.buildFrom(descriptorSet.getFile(0), new Descriptors.FileDescriptor[0]); + + Descriptors.Descriptor messageDescriptor = fileDescriptor.findMessageTypeByName("RootMessage"); + Descriptors.Descriptor repeatedMessageDescriptor = fileDescriptor.findMessageTypeByName("RepeatedMessage"); + Descriptors.EnumDescriptor enumValueDescriptor = fileDescriptor.findEnumTypeByName("TestEnum"); + + DynamicMessage repeatedMessage1 = DynamicMessage + .newBuilder(repeatedMessageDescriptor) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(1), true) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(1), false) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(2), "Test text1") + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(2), "Test text2") + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(3), Integer.MAX_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(3), Integer.MAX_VALUE - 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(4), -1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(4), -2) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(5), Integer.MIN_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(5), Integer.MIN_VALUE + 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(6), -2) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(6), -3) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(7), Integer.MAX_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(7), Integer.MAX_VALUE - 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(8), Double.MAX_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(8), Double.MAX_VALUE - 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(9), Float.MAX_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(9), Float.MAX_VALUE - 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(10), "Test bytes1".getBytes()) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(10), "Test bytes2".getBytes()) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(11), Long.MAX_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(11), Long.MAX_VALUE - 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(12), -1L) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(12), -2L) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(13), Long.MIN_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(13), Long.MIN_VALUE + 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(14), -2L) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(14), -1L) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(15), Long.MAX_VALUE) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(15), Long.MAX_VALUE - 1) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(16), enumValueDescriptor.findValueByNumber(1)) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(16), enumValueDescriptor.findValueByNumber(2)) + .build(); + + DynamicMessage repeatedMessage2 = DynamicMessage + .newBuilder(repeatedMessageDescriptor) + .addRepeatedField(repeatedMessageDescriptor.findFieldByNumber(1), true) + .build(); + + DynamicMessage rootMessage = DynamicMessage + .newBuilder(messageDescriptor) + .addRepeatedField(messageDescriptor.findFieldByNumber(1), repeatedMessage1) + .addRepeatedField(messageDescriptor.findFieldByNumber(1), repeatedMessage2) + .build(); + + return rootMessage.toByteString().newInput(); + } + public static InputStream generateInputDataForProto2() throws IOException, Descriptors.DescriptorValidationException { DescriptorProtos.FileDescriptorSet anyDescriptorSet = DescriptorProtos.FileDescriptorSet.parseFrom(new FileInputStream(BASE_TEST_PATH + "google/protobuf/any.desc")); Descriptors.FileDescriptor anyDesc = Descriptors.FileDescriptor.buildFrom(anyDescriptorSet.getFile(0), new Descriptors.FileDescriptor[]{}); diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/converter/TestProtobufDataConverter.java b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/converter/TestProtobufDataConverter.java index 9b4bdabe78..7aeafa895a 100644 --- a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/converter/TestProtobufDataConverter.java +++ b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/converter/TestProtobufDataConverter.java @@ -20,7 +20,6 @@ import com.google.protobuf.Descriptors; import com.squareup.wire.schema.Schema; import org.apache.nifi.serialization.record.MapRecord; import org.apache.nifi.serialization.record.RecordSchema; -import org.apache.nifi.serialization.record.util.DataTypeUtils; import org.apache.nifi.services.protobuf.ProtoTestUtil; import org.apache.nifi.services.protobuf.schema.ProtoSchemaParser; import org.junit.jupiter.api.Test; @@ -31,6 +30,7 @@ import java.util.Map; import static org.apache.nifi.services.protobuf.ProtoTestUtil.loadProto2TestSchema; import static org.apache.nifi.services.protobuf.ProtoTestUtil.loadProto3TestSchema; +import static org.apache.nifi.services.protobuf.ProtoTestUtil.loadRepeatedProto3TestSchema; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; @@ -58,22 +58,58 @@ public class TestProtobufDataConverter { assertEquals(Float.MAX_VALUE, record.getValue("floatField")); assertArrayEquals("Test bytes".getBytes(), (byte[]) record.getValue("bytesField")); assertEquals(Long.MAX_VALUE, record.getValue("int64Field")); - assertEquals(new BigInteger("18446744073709551615"), DataTypeUtils.toBigInt(record.getValue("uint64Field"), "field12")); + assertEquals(new BigInteger("18446744073709551615"), record.getValue("uint64Field")); assertEquals(Long.MIN_VALUE, record.getValue("sint64Field")); - assertEquals(new BigInteger("18446744073709551614"), DataTypeUtils.toBigInt(record.getValue("fixed64Field"), "field14")); + assertEquals(new BigInteger("18446744073709551614"), record.getValue("fixed64Field")); assertEquals(Long.MAX_VALUE, record.getValue("sfixed64Field")); final MapRecord nestedRecord = (MapRecord) record.getValue("nestedMessage"); assertEquals("ENUM_VALUE_3", nestedRecord.getValue("testEnum")); - assertArrayEquals(new Object[]{"Repeated 1", "Repeated 2", "Repeated 3"}, (Object[]) nestedRecord.getValue("repeatedField")); + assertEquals(Map.of("test_key_entry1", 101, "test_key_entry2", 202), nestedRecord.getValue("testMap")); // assert only one field is set in the OneOf field assertNull(nestedRecord.getValue("stringOption")); assertNull(nestedRecord.getValue("booleanOption")); assertEquals(3, nestedRecord.getValue("int32Option")); + } - assertEquals(Map.of("test_key_entry1", 101, "test_key_entry2", 202), nestedRecord.getValue("testMap")); + @Test + public void testDataConverterForRepeatedProto3() throws Descriptors.DescriptorValidationException, IOException { + final Schema schema = loadRepeatedProto3TestSchema(); + final RecordSchema recordSchema = new ProtoSchemaParser(schema).createSchema("RootMessage"); + + final ProtobufDataConverter dataConverter = new ProtobufDataConverter(schema, "RootMessage", recordSchema, false, false); + final MapRecord record = dataConverter.createRecord(ProtoTestUtil.generateInputDataForRepeatedProto3()); + + final Object[] repeatedMessage = (Object[]) record.getValue("repeatedMessage"); + final MapRecord record1 = (MapRecord) repeatedMessage[0]; + + assertArrayEquals(new Object[]{true, false}, (Object[]) record1.getValue("booleanField")); + assertArrayEquals(new Object[]{"Test text1", "Test text2"}, (Object[]) record1.getValue("stringField")); + assertArrayEquals(new Object[]{Integer.MAX_VALUE, Integer.MAX_VALUE - 1}, (Object[]) record1.getValue("int32Field")); + assertArrayEquals(new Object[]{4294967295L, 4294967294L}, (Object[]) record1.getValue("uint32Field")); + assertArrayEquals(new Object[]{Integer.MIN_VALUE, Integer.MIN_VALUE + 1}, (Object[]) record1.getValue("sint32Field")); + assertArrayEquals(new Object[]{4294967294L, 4294967293L}, (Object[]) record1.getValue("fixed32Field")); + assertArrayEquals(new Object[]{Integer.MAX_VALUE, Integer.MAX_VALUE - 1}, (Object[]) record1.getValue("sfixed32Field")); + assertArrayEquals(new Object[]{Double.MAX_VALUE, Double.MAX_VALUE - 1}, (Object[]) record1.getValue("doubleField")); + assertArrayEquals(new Object[]{Float.MAX_VALUE, Float.MAX_VALUE - 1}, (Object[]) record1.getValue("floatField")); + assertArrayEquals(new Object[]{Long.MAX_VALUE, Long.MAX_VALUE - 1}, (Object[]) record1.getValue("int64Field")); + assertArrayEquals(new Object[]{Long.MIN_VALUE, Long.MIN_VALUE + 1}, (Object[]) record1.getValue("sint64Field")); + assertArrayEquals(new Object[]{Long.MAX_VALUE, Long.MAX_VALUE - 1}, (Object[]) record1.getValue("sfixed64Field")); + assertArrayEquals(new Object[]{"ENUM_VALUE_2", "ENUM_VALUE_3"}, (Object[]) record1.getValue("testEnum")); + + final Object[] uint64FieldValues = (Object[]) record1.getValue("uint64Field"); + assertEquals(new BigInteger("18446744073709551615"), uint64FieldValues[0]); + assertEquals(new BigInteger("18446744073709551614"), uint64FieldValues[1]); + + final Object[] bytesFieldValues = (Object[]) record1.getValue("bytesField"); + assertArrayEquals("Test bytes1".getBytes(), (byte[]) bytesFieldValues[0]); + assertArrayEquals("Test bytes2".getBytes(), (byte[]) bytesFieldValues[1]); + + final MapRecord record2 = (MapRecord) repeatedMessage[1]; + + assertArrayEquals(new Object[]{true}, (Object[]) record2.getValue("booleanField")); } @Test diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/schema/TestProtoSchemaParser.java b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/schema/TestProtoSchemaParser.java index d313bb595c..42bb858eab 100644 --- a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/schema/TestProtoSchemaParser.java +++ b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/java/org/apache/nifi/services/protobuf/schema/TestProtoSchemaParser.java @@ -20,12 +20,15 @@ import org.apache.nifi.serialization.SimpleRecordSchema; import org.apache.nifi.serialization.record.RecordField; import org.apache.nifi.serialization.record.RecordFieldType; import org.apache.nifi.serialization.record.RecordSchema; +import org.apache.nifi.serialization.record.type.ArrayDataType; +import org.apache.nifi.serialization.record.type.RecordDataType; import org.junit.jupiter.api.Test; import java.util.Arrays; import static org.apache.nifi.services.protobuf.ProtoTestUtil.loadProto2TestSchema; import static org.apache.nifi.services.protobuf.ProtoTestUtil.loadProto3TestSchema; +import static org.apache.nifi.services.protobuf.ProtoTestUtil.loadRepeatedProto3TestSchema; import static org.junit.jupiter.api.Assertions.assertEquals; public class TestProtoSchemaParser { @@ -52,7 +55,6 @@ public class TestProtoSchemaParser { new RecordField("sfixed64Field", RecordFieldType.LONG.getDataType()), new RecordField("nestedMessage", RecordFieldType.RECORD.getRecordDataType(new SimpleRecordSchema(Arrays.asList( new RecordField("testEnum", RecordFieldType.ENUM.getEnumDataType(Arrays.asList("ENUM_VALUE_1", "ENUM_VALUE_2", "ENUM_VALUE_3"))), - new RecordField("repeatedField", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.STRING.getDataType())), new RecordField("testMap", RecordFieldType.MAP.getMapDataType(RecordFieldType.INT.getDataType())), new RecordField("stringOption", RecordFieldType.STRING.getDataType()), new RecordField("booleanOption", RecordFieldType.BOOLEAN.getDataType()), @@ -64,6 +66,38 @@ public class TestProtoSchemaParser { assertEquals(expected, actual); } + @Test + public void testSchemaParserForRepeatedProto3() { + final ProtoSchemaParser schemaParser = new ProtoSchemaParser(loadRepeatedProto3TestSchema()); + + final SimpleRecordSchema expected = + new SimpleRecordSchema(Arrays.asList( + new RecordField("booleanField", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BOOLEAN.getDataType())), + new RecordField("stringField", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.STRING.getDataType())), + new RecordField("int32Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.INT.getDataType())), + new RecordField("uint32Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.LONG.getDataType())), + new RecordField("sint32Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.LONG.getDataType())), + new RecordField("fixed32Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.LONG.getDataType())), + new RecordField("sfixed32Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.INT.getDataType())), + new RecordField("doubleField", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.DOUBLE.getDataType())), + new RecordField("floatField", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.FLOAT.getDataType())), + new RecordField("bytesField", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BYTE.getDataType()))), + new RecordField("int64Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.LONG.getDataType())), + new RecordField("uint64Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BIGINT.getDataType())), + new RecordField("sint64Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.LONG.getDataType())), + new RecordField("fixed64Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BIGINT.getDataType())), + new RecordField("sfixed64Field", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.LONG.getDataType())), + new RecordField("testEnum", RecordFieldType.ARRAY.getArrayDataType( + RecordFieldType.ENUM.getEnumDataType(Arrays.asList("ENUM_VALUE_1", "ENUM_VALUE_2", "ENUM_VALUE_3")))) + )); + + final RecordSchema actual = schemaParser.createSchema("RootMessage"); + final ArrayDataType arrayDataType = (ArrayDataType) actual.getField("repeatedMessage").get().getDataType(); + final RecordDataType recordDataType = (RecordDataType) arrayDataType.getElementType(); + + assertEquals(expected, recordDataType.getChildSchema()); + } + @Test public void testSchemaParserForProto2() { final ProtoSchemaParser schemaParser = new ProtoSchemaParser(loadProto2TestSchema()); diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.desc b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.desc index a2316f3f87..1dbfb60613 100644 Binary files a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.desc and b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.desc differ diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.proto b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.proto index a6ddec0e61..3e7a736cd3 100644 --- a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.proto +++ b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.proto @@ -37,13 +37,12 @@ message Proto3Message { message NestedMessage { TestEnum testEnum = 20; - repeated string repeatedField = 21; + map<string, int32> testMap = 21; oneof oneOfField { string stringOption = 22; bool booleanOption = 23; int32 int32Option = 24; } - map<string, int32> testMap = 25; } enum TestEnum { diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_repeated_proto3.desc b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_repeated_proto3.desc new file mode 100644 index 0000000000..70811cb388 Binary files /dev/null and b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_repeated_proto3.desc differ diff --git a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.proto b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_repeated_proto3.proto similarity index 56% copy from nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.proto copy to nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_repeated_proto3.proto index a6ddec0e61..6af31f5bc2 100644 --- a/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_proto3.proto +++ b/nifi-extension-bundles/nifi-protobuf-bundle/nifi-protobuf-services/src/test/resources/test_repeated_proto3.proto @@ -16,34 +16,27 @@ */ syntax = "proto3"; -message Proto3Message { - bool booleanField = 1; - string stringField = 2; - int32 int32Field = 3; - uint32 uint32Field = 4; - sint32 sint32Field = 5; - fixed32 fixed32Field = 6; - sfixed32 sfixed32Field = 7; - double doubleField = 8; - float floatField = 9; - bytes bytesField = 10; - int64 int64Field = 11; - uint64 uint64Field = 12; - sint64 sint64Field = 13; - fixed64 fixed64Field = 14; - sfixed64 sfixed64Field = 15; - NestedMessage nestedMessage = 16; +message RootMessage { + repeated RepeatedMessage repeatedMessage = 1; } -message NestedMessage { - TestEnum testEnum = 20; - repeated string repeatedField = 21; - oneof oneOfField { - string stringOption = 22; - bool booleanOption = 23; - int32 int32Option = 24; - } - map<string, int32> testMap = 25; +message RepeatedMessage { + repeated bool booleanField = 1; + repeated string stringField = 2; + repeated int32 int32Field = 3; + repeated uint32 uint32Field = 4; + repeated sint32 sint32Field = 5; + repeated fixed32 fixed32Field = 6; + repeated sfixed32 sfixed32Field = 7; + repeated double doubleField = 8; + repeated float floatField = 9; + repeated bytes bytesField = 10; + repeated int64 int64Field = 11; + repeated uint64 uint64Field = 12; + repeated sint64 sint64Field = 13; + repeated fixed64 fixed64Field = 14; + repeated sfixed64 sfixed64Field = 15; + repeated TestEnum testEnum = 16; } enum TestEnum {