Github user StephanEwen commented on a diff in the pull request: https://github.com/apache/flink/pull/5995#discussion_r188340240 --- Diff: flink-formats/flink-avro/src/main/java/org/apache/flink/formats/avro/AvroDeserializationSchema.java --- @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.formats.avro; + +import org.apache.flink.api.common.serialization.DeserializationSchema; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.formats.avro.typeutils.AvroTypeInfo; +import org.apache.flink.formats.avro.typeutils.GenericRecordAvroTypeInfo; +import org.apache.flink.formats.avro.utils.MutableByteArrayInputStream; +import org.apache.flink.util.Preconditions; + +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.Decoder; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.reflect.ReflectDatumReader; +import org.apache.avro.specific.SpecificData; +import org.apache.avro.specific.SpecificDatumReader; +import org.apache.avro.specific.SpecificRecord; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.io.ObjectInputStream; + +/** + * Deserialization schema that deserializes from Avro binary format. + * + * @param <T> type of record it produces + */ +public class AvroDeserializationSchema<T> implements DeserializationSchema<T> { + + /** + * Class to deserialize to. + */ + private Class<T> recordClazz; + + private String schemaString = null; + + /** + * Reader that deserializes byte array into a record. + */ + private transient GenericDatumReader<T> datumReader; + + /** + * Input stream to read message from. + */ + private transient MutableByteArrayInputStream inputStream; + + /** + * Avro decoder that decodes binary data. + */ + private transient Decoder decoder; + + /** + * Avro schema for the reader. + */ + private transient Schema reader; + + /** + * Creates a Avro deserialization schema. + * + * @param recordClazz class to which deserialize. Should be one of: + * {@link org.apache.avro.specific.SpecificRecord}, + * {@link org.apache.avro.generic.GenericRecord}. 
+ * @param reader reader's Avro schema. Should be provided if recordClazz is + * {@link GenericRecord} + */ + AvroDeserializationSchema(Class<T> recordClazz, @Nullable Schema reader) { + Preconditions.checkNotNull(recordClazz, "Avro record class must not be null."); + this.recordClazz = recordClazz; + this.inputStream = new MutableByteArrayInputStream(); + this.decoder = DecoderFactory.get().binaryDecoder(inputStream, null); + this.reader = reader; + if (reader != null) { + this.schemaString = reader.toString(); + } + } + + /** + * Creates {@link AvroDeserializationSchema} that produces {@link GenericRecord} using provided schema. + * + * @param schema schema of produced records + * @return deserialized record in form of {@link GenericRecord} + */ + public static AvroDeserializationSchema<GenericRecord> forGeneric(Schema schema) { + return new AvroDeserializationSchema<>(GenericRecord.class, schema); + } + + /** + * Creates {@link AvroDeserializationSchema} that produces classes that were generated from avro schema. 
+ * + * @param tClass class of record to be produced + * @return deserialized record + */ + public static <T extends SpecificRecord> AvroDeserializationSchema<T> forSpecific(Class<T> tClass) { + return new AvroDeserializationSchema<>(tClass, null); + } + + GenericDatumReader<T> getDatumReader() { + if (datumReader != null) { + return datumReader; + } + + if (SpecificRecord.class.isAssignableFrom(recordClazz)) { + this.datumReader = new SpecificDatumReader<>(); + } else if (GenericRecord.class.isAssignableFrom(recordClazz)) { + this.datumReader = new GenericDatumReader<>(); + } else { + this.datumReader = new ReflectDatumReader<>(); + } + + return datumReader; + } + + Schema getReaderSchema() { + if (reader != null) { + return reader; + } + + if (SpecificRecord.class.isAssignableFrom(recordClazz)) { + this.reader = SpecificData.get().getSchema(recordClazz); + } else if (GenericRecord.class.isAssignableFrom(recordClazz)) { + throw new IllegalStateException( + "Cannot infer schema for generic record. 
Please pass explicit schema in the ctor."); + } else { + this.reader = ReflectData.get().getSchema(recordClazz); + } + + return reader; + } + + MutableByteArrayInputStream getInputStream() { + return inputStream; + } + + Decoder getDecoder() { + return decoder; + } + + @Override + public T deserialize(byte[] message) { + // read record + try { + inputStream.setBuffer(message); + Schema readerSchema = getReaderSchema(); + GenericDatumReader<T> datumReader = getDatumReader(); + + datumReader.setSchema(readerSchema); + + return datumReader.read(null, decoder); + } catch (Exception e) { + throw new RuntimeException("Failed to deserialize message.", e); + } + } + + private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { + ois.defaultReadObject(); + this.inputStream = new MutableByteArrayInputStream(); + this.decoder = DecoderFactory.get().binaryDecoder(inputStream, null); + if (schemaString != null) { + this.reader = new Schema.Parser().parse(schemaString); + } + } + + @Override + public boolean isEndOfStream(T nextElement) { + return false; + } + + @Override + @SuppressWarnings("unchecked") + public TypeInformation<T> getProducedType() { + if (SpecificRecord.class.isAssignableFrom(recordClazz)) { + return new AvroTypeInfo(recordClazz, false); --- End diff -- Avoid use of raw types. Generic variant should be possible here: `return new AvroTypeInfo<>(recordClazz.asSubclass(SpecificRecord.class), false);`
---