This is an automated email from the ASF dual-hosted git repository. mmack pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 6de9565fdcd [Java][Schemas] Improve performance of GetterBasedSchemaProvider#fromRowFunction (closes #27533) (#27534) 6de9565fdcd is described below commit 6de9565fdcd8617d7b072f5cd323be8768b61d9e Author: Moritz Mack <mm...@talend.com> AuthorDate: Wed Aug 30 11:58:42 2023 +0200 [Java][Schemas] Improve performance of GetterBasedSchemaProvider#fromRowFunction (closes #27533) (#27534) --- .../beam/sdk/schemas/FromRowUsingCreator.java | 320 ++++++++++++--------- 1 file changed, 177 insertions(+), 143 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java index ab9a6317efc..53c098599c3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java @@ -17,10 +17,11 @@ */ package org.apache.beam.sdk.schemas; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; -import java.lang.reflect.Type; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.util.Collection; import java.util.List; import java.util.Map; @@ -37,6 +38,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collec import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** Function to convert a {@link Row} to a user type using a creator factory. */ @@ -44,188 +46,220 @@ import org.checkerframework.checker.nullness.qual.Nullable; "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" }) -class FromRowUsingCreator<T> implements SerializableFunction<Row, T> { +class FromRowUsingCreator<T> implements SerializableFunction<Row, T>, Function<Row, T> { private final Class<T> clazz; private final GetterBasedSchemaProvider schemaProvider; private final Factory<SchemaUserTypeCreator> schemaTypeCreatorFactory; - private final Factory<List<FieldValueTypeInformation>> fieldValueTypeInformationFactory; + + @SuppressFBWarnings("SE_TRANSIENT_FIELD_NOT_RESTORED") + private transient @MonotonicNonNull Function[] fieldConverters; public FromRowUsingCreator(Class<T> clazz, GetterBasedSchemaProvider schemaProvider) { + this(clazz, schemaProvider, new CachingFactory<>(schemaProvider::schemaTypeCreator), null); + } + + private FromRowUsingCreator( + Class<T> clazz, + GetterBasedSchemaProvider schemaProvider, + Factory<SchemaUserTypeCreator> schemaTypeCreatorFactory, + @Nullable Function[] fieldConverters) { this.clazz = clazz; this.schemaProvider = schemaProvider; - this.schemaTypeCreatorFactory = new CachingFactory<>(schemaProvider::schemaTypeCreator); - this.fieldValueTypeInformationFactory = - new CachingFactory<>(schemaProvider::fieldValueTypeInformations); + this.schemaTypeCreatorFactory = schemaTypeCreatorFactory; + this.fieldConverters = fieldConverters; } @Override - public T apply(Row row) { - return fromRow(row, clazz, fieldValueTypeInformationFactory); - } - @SuppressWarnings("unchecked") - public <ValueT> ValueT fromRow( - Row row, Class<ValueT> clazz, Factory<List<FieldValueTypeInformation>> typeFactory) { + public T apply(Row row) { + if (row == null) { + return null; + } if (row instanceof RowWithGetters) { Object target = ((RowWithGetters) row).getGetterTarget(); if (target.getClass().equals(clazz)) { // Efficient path: simply extract the underlying object instead of creating a new one. - return (ValueT) target; + return (T) target; } } + if (fieldConverters == null) { + initFieldConverters(row.getSchema()); + } + checkState(fieldConverters.length == row.getFieldCount(), "Unexpected field count"); Object[] params = new Object[row.getFieldCount()]; - Schema schema = row.getSchema(); - List<FieldValueTypeInformation> typeInformations = typeFactory.create(clazz, schema); + for (int i = 0; i < row.getFieldCount(); ++i) { + params[i] = fieldConverters[i].apply(row.getValue(i)); + } + SchemaUserTypeCreator creator = schemaTypeCreatorFactory.create(clazz, row.getSchema()); + return (T) creator.create(params); + } + + private synchronized void initFieldConverters(Schema schema) { + if (fieldConverters == null) { + CachingFactory<List<FieldValueTypeInformation>> typeFactory = + new CachingFactory<>(schemaProvider::fieldValueTypeInformations); + fieldConverters = fieldConverters(clazz, schema, typeFactory); + } + } + + private Function[] fieldConverters( + Class<?> clazz, Schema schema, Factory<List<FieldValueTypeInformation>> typeFactory) { + List<FieldValueTypeInformation> typeInfos = typeFactory.create(clazz, schema); checkState( - typeInformations.size() == row.getFieldCount(), + typeInfos.size() == schema.getFieldCount(), "Did not have a matching number of type informations and fields."); - - for (int i = 0; i < row.getFieldCount(); ++i) { - FieldType type = schema.getField(i).getType(); - FieldValueTypeInformation typeInformation = checkNotNull(typeInformations.get(i)); - params[i] = - fromValue( - type, row.getValue(i), typeInformation.getRawType(), typeInformation, typeFactory); + Function[] converters = new Function[schema.getFieldCount()]; + for (int i = 0; i < converters.length; i++) { + converters[i] = fieldConverter(schema.getField(i).getType(), typeInfos.get(i), typeFactory); } + return converters; + } - SchemaUserTypeCreator creator = schemaTypeCreatorFactory.create(clazz, schema); - return (ValueT) creator.create(params); + private static boolean needsConversion(FieldType type) { + TypeName typeName = type.getTypeName(); + return typeName.equals(TypeName.ROW) + || typeName.isLogicalType() + || ((typeName.equals(TypeName.ARRAY) || typeName.equals(TypeName.ITERABLE)) + && needsConversion(type.getCollectionElementType())) + || (typeName.equals(TypeName.MAP) + && (needsConversion(type.getMapKeyType()) || needsConversion(type.getMapValueType()))); } - @SuppressWarnings("unchecked") - private @Nullable <ValueT> ValueT fromValue( + private Function fieldConverter( FieldType type, - ValueT value, - Type fieldType, - FieldValueTypeInformation fieldValueTypeInformation, + FieldValueTypeInformation typeInfo, Factory<List<FieldValueTypeInformation>> typeFactory) { - FieldValueTypeInformation elementType = fieldValueTypeInformation.getElementType(); - FieldValueTypeInformation keyType = fieldValueTypeInformation.getMapKeyType(); - FieldValueTypeInformation valueType = fieldValueTypeInformation.getMapValueType(); - if (value == null) { - return null; - } - if (TypeName.ROW.equals(type.getTypeName())) { - return (ValueT) fromRow((Row) value, (Class) fieldType, typeFactory); + if (!needsConversion(type)) { + return FieldConverter.IDENTITY; + } else if (TypeName.ROW.equals(type.getTypeName())) { + Function[] converters = + fieldConverters(typeInfo.getRawType(), type.getRowSchema(), typeFactory); + return new FromRowUsingCreator( + typeInfo.getRawType(), schemaProvider, schemaTypeCreatorFactory, converters); } else if (TypeName.ARRAY.equals(type.getTypeName())) { - return (ValueT) - fromCollectionValue( - type.getCollectionElementType(), (Collection) value, elementType, typeFactory); + return new ConvertCollection( + fieldConverter(type.getCollectionElementType(), typeInfo.getElementType(), typeFactory)); } else if (TypeName.ITERABLE.equals(type.getTypeName())) { - return (ValueT) - fromIterableValue( - type.getCollectionElementType(), (Iterable) value, elementType, typeFactory); - } - if (TypeName.MAP.equals(type.getTypeName())) { - return (ValueT) - fromMapValue( - type.getMapKeyType(), - type.getMapValueType(), - (Map) value, - keyType, - valueType, - typeFactory); - } else { - if (type.isLogicalType(OneOfType.IDENTIFIER)) { - OneOfType oneOfType = type.getLogicalType(OneOfType.class); - EnumerationType oneOfEnum = oneOfType.getCaseEnumType(); - OneOfType.Value oneOfValue = (OneOfType.Value) value; - FieldValueTypeInformation oneOfFieldValueTypeInformation = - checkNotNull( - fieldValueTypeInformation - .getOneOfTypes() - .get(oneOfEnum.toString(oneOfValue.getCaseType()))); - Object fromValue = - fromValue( - oneOfType.getFieldType(oneOfValue), - oneOfValue.getValue(), - oneOfFieldValueTypeInformation.getRawType(), - oneOfFieldValueTypeInformation, - typeFactory); - return (ValueT) oneOfType.createValue(oneOfValue.getCaseType(), fromValue); - } else if (type.getTypeName().isLogicalType()) { - Schema.LogicalType<ValueT, ValueT> logicalType = - (Schema.LogicalType<ValueT, ValueT>) type.getLogicalType(); - return logicalType.toBaseType(value); + return new ConvertIterable( + fieldConverter(type.getCollectionElementType(), typeInfo.getElementType(), typeFactory)); + } else if (TypeName.MAP.equals(type.getTypeName())) { + return new ConvertMap( + fieldConverter(type.getMapKeyType(), typeInfo.getMapKeyType(), typeFactory), + fieldConverter(type.getMapValueType(), typeInfo.getMapValueType(), typeFactory)); + } else if (type.isLogicalType(OneOfType.IDENTIFIER)) { + OneOfType oneOfType = type.getLogicalType(OneOfType.class); + Schema schema = oneOfType.getOneOfSchema(); + Map<Integer, Function> readers = Maps.newHashMapWithExpectedSize(schema.getFieldCount()); + oneOfType + .getCaseEnumType() + .getValuesMap() + .forEach( + (name, id) -> { + FieldType caseType = schema.getField(name).getType(); + FieldValueTypeInformation caseTypeInfo = + checkNotNull(typeInfo.getOneOfTypes().get(name)); + readers.put(id, fieldConverter(caseType, caseTypeInfo, typeFactory)); + }); + return new ConvertOneOf(oneOfType, readers); + } else if (type.getTypeName().isLogicalType()) { + return new ConvertLogicalType<>(type.getLogicalType()); + } + return FieldConverter.IDENTITY; + } + + private interface FieldConverter<FieldT, ValueT> + extends SerializableFunction<FieldT, ValueT>, Function<FieldT, ValueT> { + Function<Object, Object> IDENTITY = v -> v; + + ValueT convert(FieldT field); + + @Override + default @Nullable ValueT apply(@Nullable FieldT fieldValue) { + return fieldValue == null ? null : convert(fieldValue); + } + } + + private static class ConvertCollection implements FieldConverter<Collection, Collection> { + final Function converter; + + ConvertCollection(Function converter) { + this.converter = converter; + } + + @Override + public Collection convert(Collection collection) { + if (collection instanceof List) { + // For performance reasons if the input is a list, make sure that we produce a list. + // Otherwise Row unwrapping is forced to physically copy the collection into a new List + // object. + return Lists.transform((List) collection, converter); + } else { + return Collections2.transform(collection, converter); } - return value; } } - private static <SourceT, DestT> Collection<DestT> transformCollection( - Collection<SourceT> collection, Function<SourceT, DestT> function) { - if (collection instanceof List) { - // For performance reasons if the input is a list, make sure that we produce a list. Otherwise - // Row unwrapping - // is forced to physically copy the collection into a new List object. - return Lists.transform((List) collection, function); - } else { - return Collections2.transform(collection, function); + private static class ConvertIterable implements FieldConverter<Iterable, Iterable> { + final Function converter; + + ConvertIterable(Function converter) { + this.converter = converter; + } + + @Override + public Iterable convert(Iterable iterable) { + return Iterables.transform(iterable, converter); } } - @SuppressWarnings("unchecked") - private <ElementT> Collection fromCollectionValue( - FieldType elementType, - Collection<ElementT> rowCollection, - FieldValueTypeInformation elementTypeInformation, - Factory<List<FieldValueTypeInformation>> typeFactory) { - return transformCollection( - rowCollection, - element -> - fromValue( - elementType, - element, - elementTypeInformation.getType().getType(), - elementTypeInformation, - typeFactory)); + private static class ConvertMap implements FieldConverter<Map, Map> { + final Function keyConverter, valueConverter; + + ConvertMap(Function keyConverter, Function valueConverter) { + this.keyConverter = keyConverter; + this.valueConverter = valueConverter; + } + + @Override + public Map convert(Map field) { + Map result = Maps.newHashMapWithExpectedSize(field.size()); + field.forEach((k, v) -> result.put(keyConverter.apply(k), valueConverter.apply(v))); + return result; + } } - @SuppressWarnings("unchecked") - private <ElementT> Iterable fromIterableValue( - FieldType elementType, - Iterable<ElementT> rowIterable, - FieldValueTypeInformation elementTypeInformation, - Factory<List<FieldValueTypeInformation>> typeFactory) { - return Iterables.transform( - rowIterable, - element -> - fromValue( - elementType, - element, - elementTypeInformation.getType().getType(), - elementTypeInformation, - typeFactory)); + private static class ConvertOneOf implements FieldConverter<OneOfType.Value, OneOfType.Value> { + final OneOfType oneOfType; + final Map<Integer, Function> converters; + + ConvertOneOf(OneOfType oneOfType, Map<Integer, Function> converters) { + this.oneOfType = oneOfType; + this.converters = converters; + } + + @Override + public OneOfType.Value convert(OneOfType.Value field) { + EnumerationType.Value caseType = field.getCaseType(); + Function converter = + checkStateNotNull( + converters.get(caseType.getValue()), "Missing OneOf converter for case %s."); + return oneOfType.createValue(caseType, converter.apply(field.getValue())); + } } - @SuppressWarnings("unchecked") - private Map<?, ?> fromMapValue( - FieldType keyType, - FieldType valueType, - Map<?, ?> map, - FieldValueTypeInformation keyTypeInformation, - FieldValueTypeInformation valueTypeInformation, - Factory<List<FieldValueTypeInformation>> typeFactory) { - Map newMap = Maps.newHashMap(); - for (Map.Entry<?, ?> entry : map.entrySet()) { - Object key = - fromValue( - keyType, - entry.getKey(), - keyTypeInformation.getType().getType(), - keyTypeInformation, - typeFactory); - Object value = - fromValue( - valueType, - entry.getValue(), - valueTypeInformation.getType().getType(), - valueTypeInformation, - typeFactory); - newMap.put(key, value); - } - return newMap; + private static class ConvertLogicalType<FieldT, ValueT> + implements FieldConverter<FieldT, ValueT> { + final Schema.LogicalType<FieldT, ValueT> logicalType; + + ConvertLogicalType(Schema.LogicalType<FieldT, ValueT> logicalType) { + this.logicalType = logicalType; + } + + @Override + public ValueT convert(FieldT field) { + return logicalType.toBaseType(field); + } } @Override