[ 
https://issues.apache.org/jira/browse/BEAM-4454?focusedWorklogId=171703&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-171703
 ]

ASF GitHub Bot logged work on BEAM-4454:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 03/Dec/18 20:46
            Start Date: 03/Dec/18 20:46
    Worklog Time Spent: 10m 
      Work Description: reuvenlax closed pull request #7181: [BEAM-4454] Add 
more AVRO utilities to convert between Beam and Avro.
URL: https://github.com/apache/beam/pull/7181
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AvroUtils.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AvroUtils.java
index b0a76976d15f..8b9f182d0848 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AvroUtils.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AvroUtils.java
@@ -20,7 +20,9 @@
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkNotNull;
 
-import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import java.math.BigDecimal;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -28,16 +30,67 @@
 import java.util.Map;
 import java.util.stream.Collectors;
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import org.apache.avro.Conversions;
+import org.apache.avro.LogicalType;
+import org.apache.avro.LogicalTypes;
+import org.apache.avro.Schema.Type;
 import org.apache.avro.generic.GenericEnumSymbol;
 import org.apache.avro.generic.GenericFixed;
 import org.apache.avro.generic.GenericRecord;
+import org.apache.avro.generic.GenericRecordBuilder;
+import org.apache.avro.reflect.ReflectData;
+import org.apache.avro.util.Utf8;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.schemas.Schema.TypeName;
 import org.apache.beam.sdk.values.Row;
+import org.joda.time.Instant;
+import org.joda.time.ReadableInstant;
 
 /** Utils to convert AVRO records to Beam rows. */
 @Experimental(Experimental.Kind.SCHEMAS)
 public class AvroUtils {
+  // Unwrap an AVRO schema into the base type and whether it is nullable.
+  static class TypeWithNullability {
+    public final org.apache.avro.Schema type;
+    public final boolean nullable;
+
+    TypeWithNullability(org.apache.avro.Schema avroSchema) {
+      if (avroSchema.getType() == org.apache.avro.Schema.Type.UNION) {
+        List<org.apache.avro.Schema> types = avroSchema.getTypes();
+
+        // optional fields in AVRO have form of:
+        // {"name": "foo", "type": ["null", "something"]}
+
+        // don't need recursion because nested unions aren't supported in AVRO
+        List<org.apache.avro.Schema> nonNullTypes =
+            types
+                .stream()
+                .filter(x -> x.getType() != org.apache.avro.Schema.Type.NULL)
+                .collect(Collectors.toList());
+
+        if (nonNullTypes.size() == types.size() || nonNullTypes.isEmpty()) {
+          // union without `null` or all 'null' union, keep as is.
+          type = avroSchema;
+          nullable = false;
+        } else if (nonNullTypes.size() > 1) {
+          type = org.apache.avro.Schema.createUnion(nonNullTypes);
+          nullable = true;
+        } else {
+          // One non-null type.
+          type = nonNullTypes.get(0);
+          nullable = true;
+        }
+      } else {
+        type = avroSchema;
+        nullable = false;
+      }
+    }
+  }
+
   private AvroUtils() {}
 
   /**
@@ -45,94 +98,332 @@ private AvroUtils() {}
    *
    * @param schema schema of type RECORD
    */
-  public static Schema toSchema(@Nonnull org.apache.avro.Schema schema) {
+  public static Schema toBeamSchema(org.apache.avro.Schema schema) {
     Schema.Builder builder = Schema.builder();
 
     for (org.apache.avro.Schema.Field field : schema.getFields()) {
-      org.apache.avro.Schema unwrapped = unwrapNullableSchema(field.schema());
-
-      if (!unwrapped.equals(field.schema())) {
-        builder.addNullableField(field.name(), toFieldType(unwrapped));
-      } else {
-        builder.addField(field.name(), toFieldType(unwrapped));
+      TypeWithNullability nullableType = new 
TypeWithNullability(field.schema());
+      Field beamField = Field.of(field.name(), toFieldType(nullableType));
+      if (field.doc() != null) {
+        beamField = beamField.withDescription(field.doc());
       }
+      builder.addField(beamField);
     }
 
     return builder.build();
   }
 
-  /** Converts AVRO schema to Beam field. */
-  public static Schema.FieldType toFieldType(@Nonnull org.apache.avro.Schema 
avroSchema) {
-    switch (avroSchema.getType()) {
-      case RECORD:
-        return Schema.FieldType.row(toSchema(avroSchema));
+  /** Converts a Beam Schema into an AVRO schema. */
+  public static org.apache.avro.Schema toAvroSchema(Schema beamSchema) {
+    List<org.apache.avro.Schema.Field> fields = Lists.newArrayList();
+    for (Schema.Field field : beamSchema.getFields()) {
+      org.apache.avro.Schema fieldSchema = getFieldSchema(field.getType());
+      org.apache.avro.Schema.Field recordField =
+          new org.apache.avro.Schema.Field(
+              field.getName(), fieldSchema, field.getDescription(), (Object) 
null);
+      fields.add(recordField);
+    }
+    org.apache.avro.Schema avroSchema = 
org.apache.avro.Schema.createRecord(fields);
+    return avroSchema;
+  }
 
-      case ENUM:
-        return Schema.FieldType.STRING;
+  /**
+   * Strict conversion from AVRO to Beam, strict because it doesn't do 
widening or narrowing during
+   * conversion. If Schema is not provided, one is inferred from the AVRO 
schema.
+   */
+  public static Row toBeamRowStrict(GenericRecord record, @Nullable Schema 
schema) {
+    if (schema == null) {
+      schema = toBeamSchema(record.getSchema());
+    }
 
-      case ARRAY:
-        Schema.FieldType elementType = 
toFieldType(avroSchema.getElementType());
-        return Schema.FieldType.array(elementType);
+    Row.Builder builder = Row.withSchema(schema);
+    org.apache.avro.Schema avroSchema = record.getSchema();
 
-      case MAP:
-        return Schema.FieldType.map(
-            Schema.FieldType.STRING, toFieldType(avroSchema.getValueType()));
+    for (Schema.Field field : schema.getFields()) {
+      Object value = record.get(field.getName());
+      org.apache.avro.Schema fieldAvroSchema = 
avroSchema.getField(field.getName()).schema();
 
-      case FIXED:
-        return Schema.FieldType.BYTES;
+      if (value == null) {
+        builder.addValue(null);
+      } else {
+        builder.addValue(convertAvroFieldStrict(value, fieldAvroSchema, 
field.getType()));
+      }
+    }
 
-      case STRING:
-        return Schema.FieldType.STRING;
+    return builder.build();
+  }
 
-      case BYTES:
-        return Schema.FieldType.BYTES;
+  /**
+   * Convert from a Beam Row to an AVRO GenericRecord. If a Schema is not 
provided, one is inferred
+   * from the Beam schema on the row.
+   */
+  public static GenericRecord toGenericRecord(
+      Row row, @Nullable org.apache.avro.Schema avroSchema) {
+    Schema beamSchema = row.getSchema();
+    // Use the provided AVRO schema if present, otherwise infer an AVRO schema 
from the row
+    // schema.
+    if (avroSchema != null && avroSchema.getFields().size() != 
beamSchema.getFieldCount()) {
+      throw new IllegalArgumentException(
+          "AVRO schema doesn't match row schema. Row schema "
+              + beamSchema
+              + ". AVRO schema + "
+              + avroSchema);
+    }
+    if (avroSchema == null) {
+      avroSchema = toAvroSchema(beamSchema);
+    }
 
-      case INT:
-        return Schema.FieldType.INT32;
+    GenericRecordBuilder builder = new GenericRecordBuilder(avroSchema);
+    for (int i = 0; i < beamSchema.getFieldCount(); ++i) {
+      Schema.Field field = beamSchema.getField(i);
+      builder.set(
+          field.getName(),
+          genericFromBeamField(
+              field.getType(), avroSchema.getField(field.getName()).schema(), 
row.getValue(i)));
+    }
+    return builder.build();
+  }
 
-      case LONG:
-        return Schema.FieldType.INT64;
+  /** Converts AVRO schema to Beam field. */
+  private static Schema.FieldType toFieldType(TypeWithNullability type) {
+    Schema.FieldType fieldType = null;
+    org.apache.avro.Schema avroSchema = type.type;
+
+    LogicalType logicalType = LogicalTypes.fromSchema(avroSchema);
+    if (logicalType != null) {
+      if (logicalType instanceof LogicalTypes.Decimal) {
+        fieldType = FieldType.DECIMAL;
+      } else if (logicalType instanceof LogicalTypes.TimestampMillis) {
+        // TODO: There is a desire to move Beam schema DATETIME to a micros 
representation. When
+        // this is done, this logical type needs to be changed.
+        fieldType = FieldType.DATETIME;
+      }
+    }
+
+    if (fieldType == null) {
+      switch (type.type.getType()) {
+        case RECORD:
+          fieldType = Schema.FieldType.row(toBeamSchema(avroSchema));
+          break;
+
+        case ENUM:
+          fieldType = Schema.FieldType.STRING;
+          break;
+
+        case ARRAY:
+          Schema.FieldType elementType =
+              toFieldType(new 
TypeWithNullability(avroSchema.getElementType()));
+          fieldType = Schema.FieldType.array(elementType);
+          break;
+
+        case MAP:
+          fieldType =
+              Schema.FieldType.map(
+                  Schema.FieldType.STRING,
+                  toFieldType(new 
TypeWithNullability(avroSchema.getValueType())));
+          break;
+
+        case FIXED:
+          fieldType = Schema.FieldType.BYTES;
+          break;
+
+        case STRING:
+          fieldType = Schema.FieldType.STRING;
+          break;
+
+        case BYTES:
+          fieldType = Schema.FieldType.BYTES;
+          break;
+
+        case INT:
+          fieldType = Schema.FieldType.INT32;
+          break;
+
+        case LONG:
+          fieldType = Schema.FieldType.INT64;
+          break;
+
+        case FLOAT:
+          fieldType = Schema.FieldType.FLOAT;
+          break;
+
+        case DOUBLE:
+          fieldType = Schema.FieldType.DOUBLE;
+          break;
+
+        case BOOLEAN:
+          fieldType = Schema.FieldType.BOOLEAN;
+          break;
+
+        case UNION:
+          throw new RuntimeException("Can't convert 'union' to FieldType");
+
+        case NULL:
+          throw new RuntimeException("Can't convert 'null' to FieldType");
+
+        default:
+          throw new AssertionError("Unexpected AVRO Schema.Type: " + 
avroSchema.getType());
+      }
+    }
+    fieldType = fieldType.withNullable(type.nullable);
+    return fieldType;
+  }
+
+  private static org.apache.avro.Schema getFieldSchema(Schema.FieldType 
fieldType) {
+    org.apache.avro.Schema baseType;
+    switch (fieldType.getTypeName()) {
+      case BYTE:
+      case INT16:
+      case INT32:
+        baseType = org.apache.avro.Schema.create(Type.INT);
+        break;
+
+      case INT64:
+        baseType = org.apache.avro.Schema.create(Type.LONG);
+        break;
+
+      case DECIMAL:
+        baseType =
+            LogicalTypes.decimal(Integer.MAX_VALUE)
+                .addToSchema(org.apache.avro.Schema.create(Type.BYTES));
+        break;
 
       case FLOAT:
-        return Schema.FieldType.FLOAT;
+        baseType = org.apache.avro.Schema.create(Type.FLOAT);
+        break;
 
       case DOUBLE:
-        return Schema.FieldType.DOUBLE;
+        baseType = org.apache.avro.Schema.create(Type.DOUBLE);
+        break;
+
+      case STRING:
+        baseType = org.apache.avro.Schema.create(Type.STRING);
+        break;
+
+      case DATETIME:
+        // TODO: There is a desire to move Beam schema DATETIME to a micros 
representation. When
+        // this is done, this logical type needs to be changed.
+        baseType =
+            
LogicalTypes.timestampMillis().addToSchema(org.apache.avro.Schema.create(Type.LONG));
+        break;
 
       case BOOLEAN:
-        return Schema.FieldType.BOOLEAN;
+        baseType = org.apache.avro.Schema.create(Type.BOOLEAN);
+        break;
 
-      case UNION:
-        throw new RuntimeException("Can't convert 'union' to FieldType");
+      case BYTES:
+        baseType = org.apache.avro.Schema.create(Type.BYTES);
+        break;
 
-      case NULL:
-        throw new RuntimeException("Can't convert 'null' to FieldType");
+      case ARRAY:
+        baseType =
+            org.apache.avro.Schema.createArray(
+                getFieldSchema(fieldType.getCollectionElementType()));
+        break;
+
+      case MAP:
+        if (fieldType.getMapKeyType().getTypeName().isStringType()) {
+          // Avro only supports string keys in maps.
+          baseType = 
org.apache.avro.Schema.createMap(getFieldSchema(fieldType.getMapValueType()));
+        } else {
+          throw new IllegalArgumentException("Avro only supports maps with 
string keys");
+        }
+        break;
+
+      case ROW:
+        baseType = toAvroSchema(fieldType.getRowSchema());
+        break;
 
       default:
-        throw new AssertionError("Unexpected AVRO Schema.Type: " + 
avroSchema.getType());
+        throw new IllegalArgumentException("Unexpected type " + fieldType);
     }
+    return fieldType.getNullable() ? ReflectData.makeNullable(baseType) : 
baseType;
   }
 
-  /**
-   * Strict conversion from AVRO to Beam, strict because it doesn't do 
widening or narrowing during
-   * conversion.
-   */
-  public static Row toRowStrict(@Nonnull GenericRecord record, @Nonnull Schema 
schema) {
-    Row.Builder builder = Row.withSchema(schema);
-    org.apache.avro.Schema avroSchema = record.getSchema();
+  private static Object genericFromBeamField(
+      Schema.FieldType fieldType, org.apache.avro.Schema avroSchema, Object 
value) {
+    org.apache.avro.Schema expectedSchema = getFieldSchema(fieldType);
+    switch (fieldType.getTypeName()) {
+      case BYTE:
+      case INT16:
+      case INT32:
+      case INT64:
+      case FLOAT:
+      case DOUBLE:
+      case BOOLEAN:
+        return checkValueType(avroSchema, value, fieldType, expectedSchema);
 
-    for (Schema.Field field : schema.getFields()) {
-      Object value = record.get(field.getName());
-      org.apache.avro.Schema fieldAvroSchema = 
avroSchema.getField(field.getName()).schema();
+      case STRING:
+        return new Utf8((String) value);
 
-      if (value == null) {
-        builder.addValue(null);
-      } else {
-        builder.addValue(convertAvroFieldStrict(value, fieldAvroSchema, 
field.getType()));
-      }
+      case DECIMAL:
+        BigDecimal decimal = (BigDecimal) value;
+        LogicalType logicalType = avroSchema.getLogicalType();
+        ByteBuffer byteBuffer =
+            new Conversions.DecimalConversion().toBytes(decimal, null, 
logicalType);
+        return checkValueType(avroSchema, byteBuffer, fieldType, 
expectedSchema);
+
+      case DATETIME:
+        ReadableInstant instant = (ReadableInstant) value;
+        return checkValueType(avroSchema, instant.getMillis(), fieldType, 
expectedSchema);
+
+      case BYTES:
+        return checkValueType(
+            avroSchema, ByteBuffer.wrap((byte[]) value), fieldType, 
expectedSchema);
+
+      case ARRAY:
+        List array = (List) checkValueType(avroSchema, value, fieldType, 
expectedSchema);
+        List<Object> translatedArray = 
Lists.newArrayListWithExpectedSize(array.size());
+        org.apache.avro.Schema avroArrayType = new 
TypeWithNullability(avroSchema).type;
+
+        for (Object arrayElement : array) {
+          translatedArray.add(
+              genericFromBeamField(
+                  fieldType.getCollectionElementType(),
+                  avroArrayType.getElementType(),
+                  arrayElement));
+        }
+        return checkValueType(avroSchema, translatedArray, fieldType, 
expectedSchema);
+
+      case MAP:
+        ImmutableMap.Builder builder = ImmutableMap.builder();
+        Map<Object, Object> valueMap =
+            (Map<Object, Object>) checkValueType(avroSchema, value, fieldType, 
expectedSchema);
+        org.apache.avro.Schema avroMapType = new 
TypeWithNullability(avroSchema).type;
+
+        for (Map.Entry entry : valueMap.entrySet()) {
+          Utf8 key = new Utf8((String) entry.getKey());
+          builder.put(
+              key,
+              genericFromBeamField(
+                  fieldType.getMapValueType(), avroMapType.getValueType(), 
entry.getValue()));
+        }
+        return checkValueType(avroSchema, builder.build(), fieldType, 
expectedSchema);
+
+      case ROW:
+        return checkValueType(
+            avroSchema, toGenericRecord((Row) value, avroSchema), fieldType, 
expectedSchema);
+
+      default:
+        throw new IllegalArgumentException("Unsupported type " + fieldType);
     }
+  }
 
-    return builder.build();
+  private static Object checkValueType(
+      org.apache.avro.Schema avroSchema,
+      Object o,
+      FieldType fieldType,
+      org.apache.avro.Schema expectedType) {
+    TypeWithNullability typeWithNullability = new 
TypeWithNullability(avroSchema);
+    if (!fieldType.getNullable().equals(typeWithNullability.nullable)) {
+      throw new IllegalArgumentException(
+          "FieldType "
+              + fieldType
+              + " and AVRO schema "
+              + avroSchema
+              + " don't have matching nullability");
+    }
+    return o;
   }
 
   /**
@@ -150,9 +441,21 @@ public static Object convertAvroFieldStrict(
       @Nonnull org.apache.avro.Schema avroSchema,
       @Nonnull Schema.FieldType fieldType) {
 
-    org.apache.avro.Schema unwrapped = unwrapNullableSchema(avroSchema);
+    TypeWithNullability type = new TypeWithNullability(avroSchema);
+    LogicalType logicalType = LogicalTypes.fromSchema(type.type);
+    if (logicalType != null) {
+      if (logicalType instanceof LogicalTypes.Decimal) {
+        ByteBuffer byteBuffer = (ByteBuffer) value;
+        BigDecimal bigDecimal =
+            new Conversions.DecimalConversion()
+                .fromBytes(byteBuffer.duplicate(), type.type, logicalType);
+        return convertDecimal(bigDecimal, fieldType);
+      } else if (logicalType instanceof LogicalTypes.TimestampMillis) {
+        return convertDateTimeStrict((Long) value, fieldType);
+      }
+    }
 
-    switch (unwrapped.getType()) {
+    switch (type.type.getType()) {
       case FIXED:
         return convertFixedStrict((GenericFixed) value, fieldType);
 
@@ -184,11 +487,11 @@ public static Object convertAvroFieldStrict(
         return convertEnumStrict((GenericEnumSymbol) value, fieldType);
 
       case ARRAY:
-        return convertArrayStrict((List<Object>) value, 
unwrapped.getElementType(), fieldType);
+        return convertArrayStrict((List<Object>) value, 
type.type.getElementType(), fieldType);
 
       case MAP:
         return convertMapStrict(
-            (Map<CharSequence, Object>) value, unwrapped.getValueType(), 
fieldType);
+            (Map<CharSequence, Object>) value, type.type.getValueType(), 
fieldType);
 
       case UNION:
         throw new IllegalArgumentException(
@@ -198,50 +501,20 @@ public static Object convertAvroFieldStrict(
         throw new IllegalArgumentException("Can't convert 'null' to 
non-nullable field");
 
       default:
-        throw new AssertionError("Unexpected AVRO Schema.Type: " + 
unwrapped.getType());
+        throw new AssertionError("Unexpected AVRO Schema.Type: " + 
type.type.getType());
     }
   }
 
-  @VisibleForTesting
-  static org.apache.avro.Schema unwrapNullableSchema(org.apache.avro.Schema 
avroSchema) {
-    if (avroSchema.getType() == org.apache.avro.Schema.Type.UNION) {
-      List<org.apache.avro.Schema> types = avroSchema.getTypes();
-
-      // optional fields in AVRO have form of:
-      // {"name": "foo", "type": ["null", "something"]}
-
-      // don't need recursion because nested unions aren't supported in AVRO
-      List<org.apache.avro.Schema> nonNullTypes =
-          types
-              .stream()
-              .filter(x -> x.getType() != org.apache.avro.Schema.Type.NULL)
-              .collect(Collectors.toList());
-
-      if (nonNullTypes.size() == types.size()) {
-        // union without `null`, keep as is
-        return avroSchema;
-      } else if (nonNullTypes.size() > 1) {
-        return org.apache.avro.Schema.createUnion(nonNullTypes);
-      } else if (nonNullTypes.size() == 1) {
-        return nonNullTypes.get(0);
-      } else { // nonNullTypes.size() == 0
-        return avroSchema;
-      }
-    }
-
-    return avroSchema;
-  }
-
   private static Object convertRecordStrict(GenericRecord record, 
Schema.FieldType fieldType) {
     checkTypeName(fieldType.getTypeName(), Schema.TypeName.ROW, "record");
-    return toRowStrict(record, fieldType.getRowSchema());
+    return toBeamRowStrict(record, fieldType.getRowSchema());
   }
 
   private static Object convertBytesStrict(ByteBuffer bb, Schema.FieldType 
fieldType) {
     checkTypeName(fieldType.getTypeName(), Schema.TypeName.BYTES, "bytes");
 
     byte[] bytes = new byte[bb.remaining()];
-    bb.get(bytes);
+    bb.duplicate().get(bytes);
     return bytes;
   }
 
@@ -265,6 +538,16 @@ private static Object convertLongStrict(Long value, 
Schema.FieldType fieldType)
     return value;
   }
 
+  private static Object convertDecimal(BigDecimal value, Schema.FieldType 
fieldType) {
+    checkTypeName(fieldType.getTypeName(), TypeName.DECIMAL, "decimal");
+    return value;
+  }
+
+  private static Object convertDateTimeStrict(Long value, Schema.FieldType 
fieldType) {
+    checkTypeName(fieldType.getTypeName(), TypeName.DATETIME, "dateTime");
+    return new Instant(value);
+  }
+
   private static Object convertFloatStrict(Float value, Schema.FieldType 
fieldType) {
     checkTypeName(fieldType.getTypeName(), Schema.TypeName.FLOAT, "float");
     return value;
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/AvroUtilsTest.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/AvroUtilsTest.java
index 6e88505b0d36..13512a0c425b 100644
--- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/AvroUtilsTest.java
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/AvroUtilsTest.java
@@ -19,21 +19,39 @@
 
 import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assume.assumeThat;
 
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
 import com.pholser.junit.quickcheck.From;
 import com.pholser.junit.quickcheck.Property;
 import com.pholser.junit.quickcheck.runner.JUnitQuickcheck;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
 import java.util.List;
 import java.util.function.Function;
+import org.apache.avro.Conversions;
+import org.apache.avro.LogicalType;
+import org.apache.avro.LogicalTypes;
 import org.apache.avro.RandomData;
 import org.apache.avro.Schema.Type;
 import org.apache.avro.generic.GenericRecord;
+import org.apache.avro.generic.GenericRecordBuilder;
+import org.apache.avro.reflect.ReflectData;
+import org.apache.avro.util.Utf8;
 import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
 import org.apache.beam.sdk.schemas.utils.AvroGenerators.RecordSchemaGenerator;
+import org.apache.beam.sdk.schemas.utils.AvroUtils.TypeWithNullability;
+import org.apache.beam.sdk.values.Row;
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
+import org.joda.time.DateTime;
+import org.joda.time.DateTimeZone;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
@@ -49,15 +67,36 @@
   public void supportsAnyAvroSchema(
       @From(RecordSchemaGenerator.class) org.apache.avro.Schema avroSchema) {
     // not everything is possible to translate
-    assumeThat(avroSchema, 
not(containsField(AvroUtilsTest::hasArrayOrMapOfNullable)));
     assumeThat(avroSchema, not(containsField(AvroUtilsTest::hasNonNullUnion)));
 
-    Schema schema = AvroUtils.toSchema(avroSchema);
+    Schema schema = AvroUtils.toBeamSchema(avroSchema);
     Iterable iterable = new RandomData(avroSchema, 10);
     List<GenericRecord> records = Lists.newArrayList((Iterable<GenericRecord>) 
iterable);
 
     for (GenericRecord record : records) {
-      AvroUtils.toRowStrict(record, schema);
+      AvroUtils.toBeamRowStrict(record, schema);
+    }
+  }
+
+  @Property(trials = 1000)
+  @SuppressWarnings("unchecked")
+  public void avroToBeamRoudTrip(
+      @From(RecordSchemaGenerator.class) org.apache.avro.Schema avroSchema) 
throws IOException {
+    // not everything is possible to translate
+    assumeThat(avroSchema, not(containsField(AvroUtilsTest::hasNonNullUnion)));
+    // roundtrip for enums returns strings because Beam doesn't have enum type
+    assumeThat(avroSchema, not(containsField(x -> x.getType() == Type.ENUM)));
+    // roundtrip for fixed returns bytes because Beam doesn't have FIXED type
+    assumeThat(avroSchema, not(containsField(x -> x.getType() == Type.FIXED)));
+
+    Schema schema = AvroUtils.toBeamSchema(avroSchema);
+    Iterable iterable = new RandomData(avroSchema, 10);
+    List<GenericRecord> records = Lists.newArrayList((Iterable<GenericRecord>) 
iterable);
+
+    for (GenericRecord record : records) {
+      Row row = AvroUtils.toBeamRowStrict(record, schema);
+      GenericRecord out = AvroUtils.toGenericRecord(row, avroSchema);
+      assertEquals(record, out);
     }
   }
 
@@ -67,8 +106,9 @@ public void testUnwrapNullableSchema() {
         org.apache.avro.Schema.createUnion(
             org.apache.avro.Schema.create(Type.NULL), 
org.apache.avro.Schema.create(Type.STRING));
 
-    assertEquals(
-        org.apache.avro.Schema.create(Type.STRING), 
AvroUtils.unwrapNullableSchema(avroSchema));
+    TypeWithNullability typeWithNullability = new 
TypeWithNullability(avroSchema);
+    assertTrue(typeWithNullability.nullable);
+    assertEquals(org.apache.avro.Schema.create(Type.STRING), 
typeWithNullability.type);
   }
 
   @Test
@@ -77,8 +117,9 @@ public void testUnwrapNullableSchemaReordered() {
         org.apache.avro.Schema.createUnion(
             org.apache.avro.Schema.create(Type.STRING), 
org.apache.avro.Schema.create(Type.NULL));
 
-    assertEquals(
-        org.apache.avro.Schema.create(Type.STRING), 
AvroUtils.unwrapNullableSchema(avroSchema));
+    TypeWithNullability typeWithNullability = new 
TypeWithNullability(avroSchema);
+    assertTrue(typeWithNullability.nullable);
+    assertEquals(org.apache.avro.Schema.create(Type.STRING), 
typeWithNullability.type);
   }
 
   @Test
@@ -89,10 +130,231 @@ public void testUnwrapNullableSchemaToUnion() {
             org.apache.avro.Schema.create(Type.LONG),
             org.apache.avro.Schema.create(Type.NULL));
 
+    TypeWithNullability typeWithNullability = new TypeWithNullability(avroSchema);
+    assertTrue(typeWithNullability.nullable);
     assertEquals(
         org.apache.avro.Schema.createUnion(
             org.apache.avro.Schema.create(Type.STRING), org.apache.avro.Schema.create(Type.LONG)),
-        AvroUtils.unwrapNullableSchema(avroSchema));
+        typeWithNullability.type);
+  }
+
+  private org.apache.avro.Schema getAvroSubSchema() {
+    List<org.apache.avro.Schema.Field> fields = Lists.newArrayList();
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "bool", org.apache.avro.Schema.create(Type.BOOLEAN), "", null));
+    fields.add(
+        new org.apache.avro.Schema.Field("int", org.apache.avro.Schema.create(Type.INT), "", null));
+    return org.apache.avro.Schema.createRecord(fields);
+  }
+
+  private org.apache.avro.Schema getAvroSchema() {
+    List<org.apache.avro.Schema.Field> fields = Lists.newArrayList();
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "bool", org.apache.avro.Schema.create(Type.BOOLEAN), "", (Object) 
null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "int", org.apache.avro.Schema.create(Type.INT), "", (Object) 
null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "long", org.apache.avro.Schema.create(Type.LONG), "", (Object) 
null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "float", org.apache.avro.Schema.create(Type.FLOAT), "", (Object) 
null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "double", org.apache.avro.Schema.create(Type.DOUBLE), "", (Object) 
null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "string", org.apache.avro.Schema.create(Type.STRING), "", (Object) 
null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "bytes", org.apache.avro.Schema.create(Type.BYTES), "", (Object) 
null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "decimal",
+            LogicalTypes.decimal(Integer.MAX_VALUE)
+                .addToSchema(org.apache.avro.Schema.create(Type.BYTES)),
+            "",
+            (Object) null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "timestampMillis",
+            LogicalTypes.timestampMillis().addToSchema(org.apache.avro.Schema.create(Type.LONG)),
+            "",
+            (Object) null));
+    fields.add(new org.apache.avro.Schema.Field("row", getAvroSubSchema(), "", (Object) null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "array", org.apache.avro.Schema.createArray(getAvroSubSchema()), 
"", (Object) null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "map", org.apache.avro.Schema.createMap(getAvroSubSchema()), "", 
(Object) null));
+    return org.apache.avro.Schema.createRecord(fields);
+  }
+
+  private Schema getBeamSubSchema() {
+    return new Schema.Builder()
+        .addField(Field.of("bool", FieldType.BOOLEAN))
+        .addField(Field.of("int", FieldType.INT32))
+        .build();
+  }
+
+  private Schema getBeamSchema() {
+    Schema subSchema = getBeamSubSchema();
+    return new Schema.Builder()
+        .addField(Field.of("bool", FieldType.BOOLEAN))
+        .addField(Field.of("int", FieldType.INT32))
+        .addField(Field.of("long", FieldType.INT64))
+        .addField(Field.of("float", FieldType.FLOAT))
+        .addField(Field.of("double", FieldType.DOUBLE))
+        .addField(Field.of("string", FieldType.STRING))
+        .addField(Field.of("bytes", FieldType.BYTES))
+        .addField(Field.of("decimal", FieldType.DECIMAL))
+        .addField(Field.of("timestampMillis", FieldType.DATETIME))
+        .addField(Field.of("row", FieldType.row(subSchema)))
+        .addField(Field.of("array", FieldType.array(FieldType.row(subSchema))))
+        .addField(Field.of("map", FieldType.map(FieldType.STRING, FieldType.row(subSchema))))
+        .build();
+  }
+
+  static final byte[] BYTE_ARRAY = new byte[] {1, 2, 3, 4};
+  static final DateTime DATE_TIME =
+      new DateTime().withDate(1979, 03, 14).withTime(1, 2, 3, 4).withZone(DateTimeZone.UTC);
+  static final BigDecimal BIG_DECIMAL = new BigDecimal(3600);
+
+  private Row getBeamRow() {
+    Row subRow = Row.withSchema(getBeamSubSchema()).addValues(true, 42).build();
+    return Row.withSchema(getBeamSchema())
+        .addValue(true)
+        .addValue(43)
+        .addValue(44L)
+        .addValue((float) 44.1)
+        .addValue((double) 44.2)
+        .addValue("string")
+        .addValue(BYTE_ARRAY)
+        .addValue(BIG_DECIMAL)
+        .addValue(DATE_TIME)
+        .addValue(subRow)
+        .addValue(ImmutableList.of(subRow, subRow))
+        .addValue(ImmutableMap.of("k1", subRow, "k2", subRow))
+        .build();
+  }
+
+  private GenericRecord getGenericRecord() {
+
+    GenericRecord subRecord =
+        new GenericRecordBuilder(getAvroSubSchema()).set("bool", true).set("int", 42).build();
+
+    LogicalType decimalType =
+        LogicalTypes.decimal(Integer.MAX_VALUE)
+            .addToSchema(org.apache.avro.Schema.create(Type.BYTES))
+            .getLogicalType();
+    ByteBuffer encodedDecimal =
+        new Conversions.DecimalConversion().toBytes(BIG_DECIMAL, null, decimalType);
+
+    return new GenericRecordBuilder(getAvroSchema())
+        .set("bool", true)
+        .set("int", 43)
+        .set("long", 44L)
+        .set("float", (float) 44.1)
+        .set("double", (double) 44.2)
+        .set("string", new Utf8("string"))
+        .set("bytes", ByteBuffer.wrap(BYTE_ARRAY))
+        .set("decimal", encodedDecimal)
+        .set("timestampMillis", DATE_TIME.getMillis())
+        .set("row", subRecord)
+        .set("array", ImmutableList.of(subRecord, subRecord))
+        .set("map", ImmutableMap.of(new Utf8("k1"), subRecord, new Utf8("k2"), subRecord))
+        .build();
+  }
+
+  @Test
+  public void testFromAvroSchema() {
+    assertEquals(getBeamSchema(), AvroUtils.toBeamSchema(getAvroSchema()));
+  }
+
+  @Test
+  public void testFromBeamSchema() {
+    Schema beamSchema = getBeamSchema();
+    org.apache.avro.Schema avroSchema = AvroUtils.toAvroSchema(beamSchema);
+    assertEquals(getAvroSchema(), avroSchema);
+  }
+
+  @Test
+  public void testNullableFieldInAvroSchema() {
+    List<org.apache.avro.Schema.Field> fields = Lists.newArrayList();
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "int", 
ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT)), "", null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "array",
+            org.apache.avro.Schema.createArray(
+                
ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT))),
+            "",
+            null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "map",
+            org.apache.avro.Schema.createMap(
+                
ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT))),
+            "",
+            null));
+    org.apache.avro.Schema avroSchema = org.apache.avro.Schema.createRecord(fields);
+
+    Schema expectedSchema =
+        Schema.builder()
+            .addNullableField("int", FieldType.INT32)
+            .addArrayField("array", FieldType.INT32.withNullable(true))
+            .addMapField("map", FieldType.STRING, 
FieldType.INT32.withNullable(true))
+            .build();
+    assertEquals(expectedSchema, AvroUtils.toBeamSchema(avroSchema));
+  }
+
+  @Test
+  public void testNullableFieldsInBeamSchema() {
+    Schema beamSchema =
+        Schema.builder()
+            .addNullableField("int", FieldType.INT32)
+            .addArrayField("array", FieldType.INT32.withNullable(true))
+            .addMapField("map", FieldType.STRING, 
FieldType.INT32.withNullable(true))
+            .build();
+
+    List<org.apache.avro.Schema.Field> fields = Lists.newArrayList();
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "int", 
ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT)), "", null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "array",
+            org.apache.avro.Schema.createArray(
+                
ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT))),
+            "",
+            null));
+    fields.add(
+        new org.apache.avro.Schema.Field(
+            "map",
+            org.apache.avro.Schema.createMap(
+                
ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT))),
+            "",
+            null));
+    org.apache.avro.Schema avroSchema = org.apache.avro.Schema.createRecord(fields);
+    assertEquals(avroSchema, AvroUtils.toAvroSchema(beamSchema));
+  }
+
+  @Test
+  public void testBeamRowToGenericRecord() {
+    GenericRecord genericRecord = AvroUtils.toGenericRecord(getBeamRow(), null);
+    assertEquals(getAvroSchema(), genericRecord.getSchema());
+    assertEquals(getGenericRecord(), genericRecord);
+  }
+
+  @Test
+  public void testGenericRecordToBeamRow() {
+    Row row = AvroUtils.toBeamRowStrict(getGenericRecord(), null);
+    assertEquals(getBeamRow(), row);
   }
 
   public static ContainsField containsField(Function<org.apache.avro.Schema, Boolean> predicate) {
@@ -114,26 +376,6 @@ public static boolean hasNonNullUnion(org.apache.avro.Schema schema) {
     return false;
   }
 
-  // doesn't work because Beam doesn't support arrays and maps of nullable types
-  public static boolean hasArrayOrMapOfNullable(org.apache.avro.Schema schema) {
-
-    if (schema.getType() == Type.ARRAY) {
-      org.apache.avro.Schema elementType = schema.getElementType();
-      if (elementType.getType() == Type.UNION) {
-        return elementType.getTypes().contains(NULL_SCHEMA);
-      }
-    }
-
-    if (schema.getType() == Type.MAP) {
-      org.apache.avro.Schema valueType = schema.getValueType();
-      if (valueType.getType() == Type.UNION) {
-        return valueType.getTypes().contains(NULL_SCHEMA);
-      }
-    }
-
-    return false;
-  }
-
   static class ContainsField extends BaseMatcher<org.apache.avro.Schema> {
 
     private final Function<org.apache.avro.Schema, Boolean> predicate;


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Issue Time Tracking
-------------------

    Worklog Id:     (was: 171703)
    Time Spent: 6h 10m  (was: 6h)

> Provide automatic schema registration for AVROs
> -----------------------------------------------
>
>                 Key: BEAM-4454
>                 URL: https://issues.apache.org/jira/browse/BEAM-4454
>             Project: Beam
>          Issue Type: Sub-task
>          Components: sdk-java-core
>            Reporter: Reuven Lax
>            Assignee: Reuven Lax
>            Priority: Major
>          Time Spent: 6h 10m
>  Remaining Estimate: 0h
>
> Need to make sure this is a compatible change



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to