Github user twalthr commented on a diff in the pull request:

    https://github.com/apache/flink/pull/6218#discussion_r199759148

--- Diff: flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroRowDeserializationSchema.java ---
@@ -17,154 +17,338 @@
 package org.apache.flink.formats.avro;

+import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.serialization.AbstractDeserializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.formats.avro.typeutils.AvroRecordClassConverter;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.MapTypeInfo;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.formats.avro.typeutils.AvroSchemaConverter;
 import org.apache.flink.formats.avro.utils.MutableByteArrayInputStream;
 import org.apache.flink.types.Row;
 import org.apache.flink.util.Preconditions;

+import org.apache.avro.LogicalTypes;
 import org.apache.avro.Schema;
-import org.apache.avro.generic.GenericRecord;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.generic.GenericDatumReader;
+import org.apache.avro.generic.GenericFixed;
+import org.apache.avro.generic.IndexedRecord;
 import org.apache.avro.io.DatumReader;
 import org.apache.avro.io.Decoder;
 import org.apache.avro.io.DecoderFactory;
 import org.apache.avro.specific.SpecificData;
 import org.apache.avro.specific.SpecificDatumReader;
 import org.apache.avro.specific.SpecificRecord;
-import org.apache.avro.specific.SpecificRecordBase;
-import org.apache.avro.util.Utf8;
+import org.joda.time.DateTime;
+import org.joda.time.DateTimeFieldType;
+import org.joda.time.LocalDate;
+import org.joda.time.LocalTime;

 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.lang.reflect.Array;
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.nio.ByteBuffer;
+import java.sql.Date;
+import java.sql.Time;
+import java.sql.Timestamp;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.TimeZone;

 /**
- * Deserialization schema from Avro bytes over {@link SpecificRecord} to {@link Row}.
+ * Deserialization schema from Avro bytes to {@link Row}.
  *
- * <p>Deserializes the <code>byte[]</code> messages into (nested) Flink Rows.
+ * <p>Deserializes the <code>byte[]</code> messages into (nested) Flink rows. It converts Avro types
+ * into types that are compatible with Flink's Table & SQL API.
  *
- * {@link Utf8} is converted to regular Java Strings.
+ * <p>Projects with Avro records containing logical date/time types need to add a JodaTime
+ * dependency.
+ *
+ * <p>Note: Changes in this class need to be kept in sync with the corresponding runtime
+ * class {@link AvroRowSerializationSchema} and schema converter {@link AvroSchemaConverter}.
  */
+@PublicEvolving
 public class AvroRowDeserializationSchema extends AbstractDeserializationSchema<Row> {

     /**
-     * Avro record class.
+     * Used for time conversions into SQL types.
+     */
+    private static final TimeZone LOCAL_TZ = TimeZone.getDefault();
+
+    /**
+     * Avro record class for deserialization. Might be null if record class is not available.
     */
     private Class<? extends SpecificRecord> recordClazz;

     /**
-     * Schema for deterministic field order.
+     * Schema string for deserialization.
+     */
+    private String schemaString;
+
+    /**
+     * Avro serialization schema.
     */
     private transient Schema schema;

     /**
-     * Reader that deserializes byte array into a record.
+     * Type information describing the result type.
     */
-    private transient DatumReader<SpecificRecord> datumReader;
+    private transient TypeInformation<Row> typeInfo;

     /**
-     * Input stream to read message from.
+     * Record to deserialize byte array.
     */
-    private transient MutableByteArrayInputStream inputStream;
+    private transient IndexedRecord record;

     /**
-     * Avro decoder that decodes binary data.
+     * Reader that deserializes byte array into a record.
     */
-    private transient Decoder decoder;
+    private transient DatumReader<IndexedRecord> datumReader;

     /**
-     * Record to deserialize byte array to.
+     * Input stream to read message from.
     */
-    private SpecificRecord record;
+    private transient MutableByteArrayInputStream inputStream;

     /**
-     * Type information describing the result type.
+     * Avro decoder that decodes binary data.
     */
-    private transient TypeInformation<Row> typeInfo;
+    private transient Decoder decoder;

     /**
-     * Creates a Avro deserialization schema for the given record.
+     * Creates a Avro deserialization schema for the given specific record class. Having the
+     * concrete Avro record class might improve performance.
     *
     * @param recordClazz Avro record class used to deserialize Avro's record to Flink's row
     */
-    public AvroRowDeserializationSchema(Class<? extends SpecificRecordBase> recordClazz) {
+    public AvroRowDeserializationSchema(Class<? extends SpecificRecord> recordClazz) {
         Preconditions.checkNotNull(recordClazz, "Avro record class must not be null.");
         this.recordClazz = recordClazz;
-        this.schema = SpecificData.get().getSchema(recordClazz);
-        this.datumReader = new SpecificDatumReader<>(schema);
-        this.record = (SpecificRecord) SpecificData.newInstance(recordClazz, schema);
-        this.inputStream = new MutableByteArrayInputStream();
-        this.decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
-        this.typeInfo = AvroRecordClassConverter.convert(recordClazz);
+        schema = SpecificData.get().getSchema(recordClazz);
+        typeInfo = AvroSchemaConverter.convert(recordClazz);
+        schemaString = schema.toString();
+        record = (SpecificRecord) SpecificData.newInstance(recordClazz, schema);
+        datumReader = new SpecificDatumReader<>(schema);
+        inputStream = new MutableByteArrayInputStream();
+        decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
+    }
+
+    /**
+     * Creates a Avro deserialization schema for the given Avro schema string.
+     *
+     * @param avroSchemaString Avro schema string to deserialize Avro's record to Flink's row
+     */
+    public AvroRowDeserializationSchema(String avroSchemaString) {
+        Preconditions.checkNotNull(avroSchemaString, "Avro schema must not be null.");
+        recordClazz = null;
+        typeInfo = AvroSchemaConverter.convert(avroSchemaString);
+        schemaString = avroSchemaString;
+        schema = new Schema.Parser().parse(avroSchemaString);
+        record = new GenericData.Record(schema);
+        datumReader = new GenericDatumReader<>(schema);
+        inputStream = new MutableByteArrayInputStream();
+        decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
+        // check for a schema that describes a record
+        if (!(typeInfo instanceof RowTypeInfo)) {
+            throw new IllegalArgumentException("Row type information expected.");
+        }
     }

     @Override
     public Row deserialize(byte[] message) throws IOException {
-        // read record
         try {
             inputStream.setBuffer(message);
-            this.record = datumReader.read(record, decoder);
-        } catch (IOException e) {
-            throw new RuntimeException("Failed to deserialize Row.", e);
+            final IndexedRecord read = datumReader.read(record, decoder);
+            return convertRecord(schema, (RowTypeInfo) typeInfo, read);
+        } catch (Exception e) {
+            throw new IOException("Failed to deserialize Avro record.", e);
         }
-
-        // convert to row
-        final Object row = convertToRow(schema, record);
-        return (Row) row;
-    }
-
-    private void writeObject(ObjectOutputStream oos) throws IOException {
-        oos.writeObject(recordClazz);
-    }
-
-    @SuppressWarnings("unchecked")
-    private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
-        this.recordClazz = (Class<? extends SpecificRecord>) ois.readObject();
-        this.schema = SpecificData.get().getSchema(recordClazz);
-        this.datumReader = new SpecificDatumReader<>(schema);
-        this.record = (SpecificRecord) SpecificData.newInstance(recordClazz, schema);
-        this.inputStream = new MutableByteArrayInputStream();
-        this.decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
     }

     @Override
     public TypeInformation<Row> getProducedType() {
         return typeInfo;
     }

-    /**
-     * Converts a (nested) Avro {@link SpecificRecord} into Flink's Row type.
-     * Avro's {@link Utf8} fields are converted into regular Java strings.
-     */
-    private static Object convertToRow(Schema schema, Object recordObj) {
-        if (recordObj instanceof GenericRecord) {
-            // records can be wrapped in a union
-            if (schema.getType() == Schema.Type.UNION) {
+    // --------------------------------------------------------------------------------------------
+
+    private Row convertRecord(Schema schema, RowTypeInfo typeInfo, IndexedRecord record) {
+        final List<Schema.Field> fields = schema.getFields();
+        final TypeInformation<?>[] fieldInfo = typeInfo.getFieldTypes();
+        final int length = fields.size();
+        final Row row = new Row(length);
+        for (int i = 0; i < length; i++) {
+            final Schema.Field field = fields.get(i);
+            row.setField(i, convert(field.schema(), fieldInfo[i], record.get(i)));
+        }
+        return row;
+    }
+
+    private Object convert(Schema schema, TypeInformation<?> info, Object object) {
+        // we perform the conversion based on schema information but enriched with pre-computed
+        // type information where useful (i.e., for arrays)
+
+        if (object == null) {
+            return null;
+        }
+        switch (schema.getType()) {
+            case RECORD:
+                if (object instanceof IndexedRecord) {
+                    return convertRecord(schema, (RowTypeInfo) info, (IndexedRecord) object);
+                }
+                throw new IllegalStateException("IndexedRecord expected but was: " + object.getClass());
+            case ENUM:
+            case STRING:
+                return object.toString();
+            case ARRAY:
+                if (info instanceof BasicArrayTypeInfo) {
+                    final BasicArrayTypeInfo<?, ?> bati = (BasicArrayTypeInfo<?, ?>) info;
+                    final TypeInformation<?> elementInfo = bati.getComponentInfo();
+                    return convertObjectArray(schema.getElementType(), elementInfo, object);
+                } else {
+                    final ObjectArrayTypeInfo<?, ?> oati = (ObjectArrayTypeInfo<?, ?>) info;
+                    final TypeInformation<?> elementInfo = oati.getComponentInfo();
+                    return convertObjectArray(schema.getElementType(), elementInfo, object);
+                }
+            case MAP:
+                final MapTypeInfo<?, ?> mti = (MapTypeInfo<?, ?>) info;
+                final Map<String, Object> convertedMap = new HashMap<>();
+                final Map<?, ?> map = (Map<?, ?>) object;
+                for (Map.Entry<?, ?> entry : map.entrySet()) {
+                    convertedMap.put(
+                        entry.getKey().toString(),
+                        convert(schema.getValueType(), mti.getValueTypeInfo(), entry.getValue()));
+                }
+                return convertedMap;
+            case UNION:
                 final List<Schema> types = schema.getTypes();
-                if (types.size() == 2 && types.get(0).getType() == Schema.Type.NULL && types.get(1).getType() == Schema.Type.RECORD) {
-                    schema = types.get(1);
+                final int size = types.size();
+                final Schema actualSchema;
+                if (size == 2 && types.get(0).getType() == Schema.Type.NULL) {
+                    return convert(types.get(1), info, object);
+                } else if (size == 2 && types.get(1).getType() == Schema.Type.NULL) {
+                    return convert(types.get(0), info, object);
+                } else if (size == 1) {
+                    return convert(types.get(0), info, object);
+                } else {
+                    // generic type
+                    return object;
+                }
+            case FIXED:
+                final byte[] fixedBytes = ((GenericFixed) object).bytes();
+                if (info == Types.BIG_DEC) {
+                    return convertDecimal(schema, fixedBytes);
                 }
-                else {
-                    throw new RuntimeException("Currently we only support schemas of the following form: UNION[null, RECORD]. Given: " + schema);
+                return fixedBytes;
+            case BYTES:
+                final ByteBuffer bb = (ByteBuffer) object;
--- End diff --

We should only add checks in runtime code if really necessary. IMHO it does not matter whether a cast exception or an illegal state exception is thrown.
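To make the point concrete: the line under discussion casts the incoming value directly (`final ByteBuffer bb = (ByteBuffer) object;`), so an unexpected type surfaces as a ClassCastException rather than a hand-written IllegalStateException. A minimal standalone sketch (illustration only, not part of the PR; the class and method names are invented) showing that both variants fail at the same point and differ only in the exception type:

    import java.nio.ByteBuffer;

    public class CastVsCheckDemo {

        // Variant 1: rely on the cast; a type mismatch surfaces as a ClassCastException.
        static ByteBuffer viaCast(Object object) {
            return (ByteBuffer) object;
        }

        // Variant 2: explicit runtime check; a mismatch surfaces as an IllegalStateException
        // carrying a hand-written message.
        static ByteBuffer viaCheck(Object object) {
            if (!(object instanceof ByteBuffer)) {
                throw new IllegalStateException("ByteBuffer expected but was: " + object.getClass());
            }
            return (ByteBuffer) object;
        }

        public static void main(String[] args) {
            final Object wrongType = "not a byte buffer";
            try {
                viaCast(wrongType);
            } catch (ClassCastException e) {
                System.out.println("cast:  " + e);
            }
            try {
                viaCheck(wrongType);
            } catch (IllegalStateException e) {
                System.out.println("check: " + e);
            }
        }
    }

Since deserialize() in the diff wraps any Exception into an IOException, callers see the same failure surface either way; the explicit check only buys a more descriptive message at the cost of an extra branch in hot deserialization code.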
---