This is an automated email from the ASF dual-hosted git repository.

mmack pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 6de9565fdcd [Java][Schemas] Improve performance of GetterBasedSchemaProvider#fromRowFunction (closes #27533) (#27534)
6de9565fdcd is described below

commit 6de9565fdcd8617d7b072f5cd323be8768b61d9e
Author: Moritz Mack <mm...@talend.com>
AuthorDate: Wed Aug 30 11:58:42 2023 +0200

    [Java][Schemas] Improve performance of GetterBasedSchemaProvider#fromRowFunction (closes #27533) (#27534)
---
 .../beam/sdk/schemas/FromRowUsingCreator.java      | 320 ++++++++++++---------
 1 file changed, 177 insertions(+), 143 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java
index ab9a6317efc..53c098599c3 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java
@@ -17,10 +17,11 @@
  */
 package org.apache.beam.sdk.schemas;
 
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
 import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
 import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
 
-import java.lang.reflect.Type;
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
@@ -37,6 +38,7 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collec
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
+import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
 /** Function to convert a {@link Row} to a user type using a creator factory. 
*/
@@ -44,188 +46,220 @@ import 
org.checkerframework.checker.nullness.qual.Nullable;
   "nullness", // TODO(https://github.com/apache/beam/issues/20497)
   "rawtypes"
 })
-class FromRowUsingCreator<T> implements SerializableFunction<Row, T> {
+class FromRowUsingCreator<T> implements SerializableFunction<Row, T>, 
Function<Row, T> {
   private final Class<T> clazz;
   private final GetterBasedSchemaProvider schemaProvider;
   private final Factory<SchemaUserTypeCreator> schemaTypeCreatorFactory;
-  private final Factory<List<FieldValueTypeInformation>> 
fieldValueTypeInformationFactory;
+
+  @SuppressFBWarnings("SE_TRANSIENT_FIELD_NOT_RESTORED")
+  private transient @MonotonicNonNull Function[] fieldConverters;
 
   public FromRowUsingCreator(Class<T> clazz, GetterBasedSchemaProvider 
schemaProvider) {
+    this(clazz, schemaProvider, new 
CachingFactory<>(schemaProvider::schemaTypeCreator), null);
+  }
+
+  private FromRowUsingCreator(
+      Class<T> clazz,
+      GetterBasedSchemaProvider schemaProvider,
+      Factory<SchemaUserTypeCreator> schemaTypeCreatorFactory,
+      @Nullable Function[] fieldConverters) {
     this.clazz = clazz;
     this.schemaProvider = schemaProvider;
-    this.schemaTypeCreatorFactory = new 
CachingFactory<>(schemaProvider::schemaTypeCreator);
-    this.fieldValueTypeInformationFactory =
-        new CachingFactory<>(schemaProvider::fieldValueTypeInformations);
+    this.schemaTypeCreatorFactory = schemaTypeCreatorFactory;
+    this.fieldConverters = fieldConverters;
   }
 
   @Override
-  public T apply(Row row) {
-    return fromRow(row, clazz, fieldValueTypeInformationFactory);
-  }
-
   @SuppressWarnings("unchecked")
-  public <ValueT> ValueT fromRow(
-      Row row, Class<ValueT> clazz, Factory<List<FieldValueTypeInformation>> 
typeFactory) {
+  public T apply(Row row) {
+    if (row == null) {
+      return null;
+    }
     if (row instanceof RowWithGetters) {
       Object target = ((RowWithGetters) row).getGetterTarget();
       if (target.getClass().equals(clazz)) {
         // Efficient path: simply extract the underlying object instead of 
creating a new one.
-        return (ValueT) target;
+        return (T) target;
       }
     }
+    if (fieldConverters == null) {
+      initFieldConverters(row.getSchema());
+    }
+    checkState(fieldConverters.length == row.getFieldCount(), "Unexpected 
field count");
 
     Object[] params = new Object[row.getFieldCount()];
-    Schema schema = row.getSchema();
-    List<FieldValueTypeInformation> typeInformations = 
typeFactory.create(clazz, schema);
+    for (int i = 0; i < row.getFieldCount(); ++i) {
+      params[i] = fieldConverters[i].apply(row.getValue(i));
+    }
+    SchemaUserTypeCreator creator = schemaTypeCreatorFactory.create(clazz, 
row.getSchema());
+    return (T) creator.create(params);
+  }
+
+  private synchronized void initFieldConverters(Schema schema) {
+    if (fieldConverters == null) {
+      CachingFactory<List<FieldValueTypeInformation>> typeFactory =
+          new CachingFactory<>(schemaProvider::fieldValueTypeInformations);
+      fieldConverters = fieldConverters(clazz, schema, typeFactory);
+    }
+  }
+
+  private Function[] fieldConverters(
+      Class<?> clazz, Schema schema, Factory<List<FieldValueTypeInformation>> 
typeFactory) {
+    List<FieldValueTypeInformation> typeInfos = typeFactory.create(clazz, 
schema);
     checkState(
-        typeInformations.size() == row.getFieldCount(),
+        typeInfos.size() == schema.getFieldCount(),
         "Did not have a matching number of type informations and fields.");
-
-    for (int i = 0; i < row.getFieldCount(); ++i) {
-      FieldType type = schema.getField(i).getType();
-      FieldValueTypeInformation typeInformation = 
checkNotNull(typeInformations.get(i));
-      params[i] =
-          fromValue(
-              type, row.getValue(i), typeInformation.getRawType(), 
typeInformation, typeFactory);
+    Function[] converters = new Function[schema.getFieldCount()];
+    for (int i = 0; i < converters.length; i++) {
+      converters[i] = fieldConverter(schema.getField(i).getType(), 
typeInfos.get(i), typeFactory);
     }
+    return converters;
+  }
 
-    SchemaUserTypeCreator creator = schemaTypeCreatorFactory.create(clazz, 
schema);
-    return (ValueT) creator.create(params);
+  private static boolean needsConversion(FieldType type) {
+    TypeName typeName = type.getTypeName();
+    return typeName.equals(TypeName.ROW)
+        || typeName.isLogicalType()
+        || ((typeName.equals(TypeName.ARRAY) || 
typeName.equals(TypeName.ITERABLE))
+            && needsConversion(type.getCollectionElementType()))
+        || (typeName.equals(TypeName.MAP)
+            && (needsConversion(type.getMapKeyType()) || 
needsConversion(type.getMapValueType())));
   }
 
-  @SuppressWarnings("unchecked")
-  private @Nullable <ValueT> ValueT fromValue(
+  private Function fieldConverter(
       FieldType type,
-      ValueT value,
-      Type fieldType,
-      FieldValueTypeInformation fieldValueTypeInformation,
+      FieldValueTypeInformation typeInfo,
       Factory<List<FieldValueTypeInformation>> typeFactory) {
-    FieldValueTypeInformation elementType = 
fieldValueTypeInformation.getElementType();
-    FieldValueTypeInformation keyType = 
fieldValueTypeInformation.getMapKeyType();
-    FieldValueTypeInformation valueType = 
fieldValueTypeInformation.getMapValueType();
-    if (value == null) {
-      return null;
-    }
-    if (TypeName.ROW.equals(type.getTypeName())) {
-      return (ValueT) fromRow((Row) value, (Class) fieldType, typeFactory);
+    if (!needsConversion(type)) {
+      return FieldConverter.IDENTITY;
+    } else if (TypeName.ROW.equals(type.getTypeName())) {
+      Function[] converters =
+          fieldConverters(typeInfo.getRawType(), type.getRowSchema(), 
typeFactory);
+      return new FromRowUsingCreator(
+          typeInfo.getRawType(), schemaProvider, schemaTypeCreatorFactory, 
converters);
     } else if (TypeName.ARRAY.equals(type.getTypeName())) {
-      return (ValueT)
-          fromCollectionValue(
-              type.getCollectionElementType(), (Collection) value, 
elementType, typeFactory);
+      return new ConvertCollection(
+          fieldConverter(type.getCollectionElementType(), 
typeInfo.getElementType(), typeFactory));
     } else if (TypeName.ITERABLE.equals(type.getTypeName())) {
-      return (ValueT)
-          fromIterableValue(
-              type.getCollectionElementType(), (Iterable) value, elementType, 
typeFactory);
-    }
-    if (TypeName.MAP.equals(type.getTypeName())) {
-      return (ValueT)
-          fromMapValue(
-              type.getMapKeyType(),
-              type.getMapValueType(),
-              (Map) value,
-              keyType,
-              valueType,
-              typeFactory);
-    } else {
-      if (type.isLogicalType(OneOfType.IDENTIFIER)) {
-        OneOfType oneOfType = type.getLogicalType(OneOfType.class);
-        EnumerationType oneOfEnum = oneOfType.getCaseEnumType();
-        OneOfType.Value oneOfValue = (OneOfType.Value) value;
-        FieldValueTypeInformation oneOfFieldValueTypeInformation =
-            checkNotNull(
-                fieldValueTypeInformation
-                    .getOneOfTypes()
-                    .get(oneOfEnum.toString(oneOfValue.getCaseType())));
-        Object fromValue =
-            fromValue(
-                oneOfType.getFieldType(oneOfValue),
-                oneOfValue.getValue(),
-                oneOfFieldValueTypeInformation.getRawType(),
-                oneOfFieldValueTypeInformation,
-                typeFactory);
-        return (ValueT) oneOfType.createValue(oneOfValue.getCaseType(), 
fromValue);
-      } else if (type.getTypeName().isLogicalType()) {
-        Schema.LogicalType<ValueT, ValueT> logicalType =
-            (Schema.LogicalType<ValueT, ValueT>) type.getLogicalType();
-        return logicalType.toBaseType(value);
+      return new ConvertIterable(
+          fieldConverter(type.getCollectionElementType(), 
typeInfo.getElementType(), typeFactory));
+    } else if (TypeName.MAP.equals(type.getTypeName())) {
+      return new ConvertMap(
+          fieldConverter(type.getMapKeyType(), typeInfo.getMapKeyType(), 
typeFactory),
+          fieldConverter(type.getMapValueType(), typeInfo.getMapValueType(), 
typeFactory));
+    } else if (type.isLogicalType(OneOfType.IDENTIFIER)) {
+      OneOfType oneOfType = type.getLogicalType(OneOfType.class);
+      Schema schema = oneOfType.getOneOfSchema();
+      Map<Integer, Function> readers = 
Maps.newHashMapWithExpectedSize(schema.getFieldCount());
+      oneOfType
+          .getCaseEnumType()
+          .getValuesMap()
+          .forEach(
+              (name, id) -> {
+                FieldType caseType = schema.getField(name).getType();
+                FieldValueTypeInformation caseTypeInfo =
+                    checkNotNull(typeInfo.getOneOfTypes().get(name));
+                readers.put(id, fieldConverter(caseType, caseTypeInfo, 
typeFactory));
+              });
+      return new ConvertOneOf(oneOfType, readers);
+    } else if (type.getTypeName().isLogicalType()) {
+      return new ConvertLogicalType<>(type.getLogicalType());
+    }
+    return FieldConverter.IDENTITY;
+  }
+
+  private interface FieldConverter<FieldT, ValueT>
+      extends SerializableFunction<FieldT, ValueT>, Function<FieldT, ValueT> {
+    Function<Object, Object> IDENTITY = v -> v;
+
+    ValueT convert(FieldT field);
+
+    @Override
+    default @Nullable ValueT apply(@Nullable FieldT fieldValue) {
+      return fieldValue == null ? null : convert(fieldValue);
+    }
+  }
+
+  private static class ConvertCollection implements FieldConverter<Collection, 
Collection> {
+    final Function converter;
+
+    ConvertCollection(Function converter) {
+      this.converter = converter;
+    }
+
+    @Override
+    public Collection convert(Collection collection) {
+      if (collection instanceof List) {
+        // For performance reasons if the input is a list, make sure that we 
produce a list.
+        // Otherwise Row unwrapping is forced to physically copy the 
collection into a new List
+        // object.
+        return Lists.transform((List) collection, converter);
+      } else {
+        return Collections2.transform(collection, converter);
       }
-      return value;
     }
   }
 
-  private static <SourceT, DestT> Collection<DestT> transformCollection(
-      Collection<SourceT> collection, Function<SourceT, DestT> function) {
-    if (collection instanceof List) {
-      // For performance reasons if the input is a list, make sure that we 
produce a list. Otherwise
-      // Row unwrapping
-      // is forced to physically copy the collection into a new List object.
-      return Lists.transform((List) collection, function);
-    } else {
-      return Collections2.transform(collection, function);
+  private static class ConvertIterable implements FieldConverter<Iterable, 
Iterable> {
+    final Function converter;
+
+    ConvertIterable(Function converter) {
+      this.converter = converter;
+    }
+
+    @Override
+    public Iterable convert(Iterable iterable) {
+      return Iterables.transform(iterable, converter);
     }
   }
 
-  @SuppressWarnings("unchecked")
-  private <ElementT> Collection fromCollectionValue(
-      FieldType elementType,
-      Collection<ElementT> rowCollection,
-      FieldValueTypeInformation elementTypeInformation,
-      Factory<List<FieldValueTypeInformation>> typeFactory) {
-    return transformCollection(
-        rowCollection,
-        element ->
-            fromValue(
-                elementType,
-                element,
-                elementTypeInformation.getType().getType(),
-                elementTypeInformation,
-                typeFactory));
+  private static class ConvertMap implements FieldConverter<Map, Map> {
+    final Function keyConverter, valueConverter;
+
+    ConvertMap(Function keyConverter, Function valueConverter) {
+      this.keyConverter = keyConverter;
+      this.valueConverter = valueConverter;
+    }
+
+    @Override
+    public Map convert(Map field) {
+      Map result = Maps.newHashMapWithExpectedSize(field.size());
+      field.forEach((k, v) -> result.put(keyConverter.apply(k), 
valueConverter.apply(v)));
+      return result;
+    }
   }
 
-  @SuppressWarnings("unchecked")
-  private <ElementT> Iterable fromIterableValue(
-      FieldType elementType,
-      Iterable<ElementT> rowIterable,
-      FieldValueTypeInformation elementTypeInformation,
-      Factory<List<FieldValueTypeInformation>> typeFactory) {
-    return Iterables.transform(
-        rowIterable,
-        element ->
-            fromValue(
-                elementType,
-                element,
-                elementTypeInformation.getType().getType(),
-                elementTypeInformation,
-                typeFactory));
+  private static class ConvertOneOf implements FieldConverter<OneOfType.Value, 
OneOfType.Value> {
+    final OneOfType oneOfType;
+    final Map<Integer, Function> converters;
+
+    ConvertOneOf(OneOfType oneOfType, Map<Integer, Function> converters) {
+      this.oneOfType = oneOfType;
+      this.converters = converters;
+    }
+
+    @Override
+    public OneOfType.Value convert(OneOfType.Value field) {
+      EnumerationType.Value caseType = field.getCaseType();
+      Function converter =
+          checkStateNotNull(
+              converters.get(caseType.getValue()), "Missing OneOf converter 
for case %s.");
+      return oneOfType.createValue(caseType, 
converter.apply(field.getValue()));
+    }
   }
 
-  @SuppressWarnings("unchecked")
-  private Map<?, ?> fromMapValue(
-      FieldType keyType,
-      FieldType valueType,
-      Map<?, ?> map,
-      FieldValueTypeInformation keyTypeInformation,
-      FieldValueTypeInformation valueTypeInformation,
-      Factory<List<FieldValueTypeInformation>> typeFactory) {
-    Map newMap = Maps.newHashMap();
-    for (Map.Entry<?, ?> entry : map.entrySet()) {
-      Object key =
-          fromValue(
-              keyType,
-              entry.getKey(),
-              keyTypeInformation.getType().getType(),
-              keyTypeInformation,
-              typeFactory);
-      Object value =
-          fromValue(
-              valueType,
-              entry.getValue(),
-              valueTypeInformation.getType().getType(),
-              valueTypeInformation,
-              typeFactory);
-      newMap.put(key, value);
-    }
-    return newMap;
+  private static class ConvertLogicalType<FieldT, ValueT>
+      implements FieldConverter<FieldT, ValueT> {
+    final Schema.LogicalType<FieldT, ValueT> logicalType;
+
+    ConvertLogicalType(Schema.LogicalType<FieldT, ValueT> logicalType) {
+      this.logicalType = logicalType;
+    }
+
+    @Override
+    public ValueT convert(FieldT field) {
+      return logicalType.toBaseType(field);
+    }
   }
 
   @Override

Reply via email to