Github user dbtsai commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21847#discussion_r206358703

--- Diff: external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala ---
@@ -165,16 +182,118 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
     result
   }
 
-  private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = {
-    if (nullable) {
+  // Resolve an Avro union against a supplied DataType, i.e. a LongType compared against
+  // a ["null", "long"] should return a schema of type Schema.Type.LONG
+  // This function also handles resolving a DataType against unions of 2 or more types, i.e.
+  // an IntType resolves against a ["int", "long", "null"] will correctly return a schema of
+  // type Schema.Type.LONG
+  private def resolveUnionType(avroType: Schema, catalystType: DataType,
+      nullable: Boolean): Schema = {
+    if (avroType.getType == Type.UNION) {
       // avro uses union to represent nullable type.
-      val fields = avroType.getTypes.asScala
-      assert(fields.length == 2)
-      val actualType = fields.filter(_.getType != NULL)
-      assert(actualType.length == 1)
+      val fieldTypes = avroType.getTypes.asScala
+
+      // If we're nullable, we need to have at least two types. Cases with more than two types
+      // are captured in test("read read-write, read-write w/ schema, read") w/ test.avro input
+      if (nullable && fieldTypes.length < 2) {
+        throw new IncompatibleSchemaException(
+          s"Cannot resolve nullable ${catalystType} against union type ${avroType}")
+      }
+
+      val actualType = catalystType match {
+        case NullType => fieldTypes.filter(_.getType == Type.NULL)
+        case BooleanType => fieldTypes.filter(_.getType == Type.BOOLEAN)
+        case ByteType => fieldTypes.filter(_.getType == Type.INT)
+        case BinaryType =>
+          val at = fieldTypes.filter(x => x.getType == Type.BYTES || x.getType == Type.FIXED)
+          if (at.length > 1) {
+            throw new IncompatibleSchemaException(
+              s"Cannot resolve schema of ${catalystType} against union ${avroType.toString}")
+          } else {
+            at
+          }
+        case ShortType | IntegerType => fieldTypes.filter(_.getType == Type.INT)
+        case LongType => fieldTypes.filter(_.getType == Type.LONG)
+        case FloatType => fieldTypes.filter(_.getType == Type.FLOAT)
+        case DoubleType => fieldTypes.filter(_.getType == Type.DOUBLE)
+        case d: DecimalType => fieldTypes.filter(_.getType == Type.STRING)
+        case StringType => fieldTypes
+          .filter(x => x.getType == Type.STRING || x.getType == Type.ENUM)
+        case DateType => fieldTypes.filter(x => x.getType == Type.INT || x.getType == Type.LONG)
+        case TimestampType => fieldTypes.filter(_.getType == Type.LONG)
+        case ArrayType(et, containsNull) =>
+          // Find array that matches the element type specified
+          fieldTypes.filter(x => x.getType == Type.ARRAY
+            && typeMatchesSchema(et, x.getElementType))
+        case st: StructType => // Find the matching record!
+          val recordTypes = fieldTypes.filter(x => x.getType == Type.RECORD)
+          if (recordTypes.length > 1) {
+            throw new IncompatibleSchemaException(
+              "Unions of multiple record types are NOT supported with user-specified schema")
+          }
+          recordTypes
+        case MapType(kt, vt, valueContainsNull) =>
+          // Find the map that matches the value type. Maps in Avro are always key type string
+          fieldTypes.filter(x => x.getType == Type.MAP && typeMatchesSchema(vt, x.getValueType))
+        case other =>
+          throw new IncompatibleSchemaException(s"Unexpected type: $other")
+      }
+
+      if (actualType.length != 1) {
+        throw new IncompatibleSchemaException(
+          s"Failed to resolve ${catalystType} against ambiguous schema ${avroType}")
+      }
       actualType.head
     } else {
       avroType
     }
   }
+
+  // Given a Schema and a DataType, do they match?
+  private def typeMatchesSchema(catalystType: DataType, avroSchema: Schema): Boolean = {
+    if (catalystType.isInstanceOf[StructType]) {
+      val avroFields = resolveUnionType(avroSchema, catalystType,
+        avroSchema.getType == Type.UNION)
+        .getFields
+      if (avroFields.size() == catalystType.asInstanceOf[StructType].length) {
+        catalystType.asInstanceOf[StructType].zip(avroFields.asScala).forall {
+          case (f1, f2) => typeMatchesSchema(f1.dataType, f2.schema)
+        }
+      } else {
+        false
+      }
+    } else {
+      val isTypeCompatible = (a: Schema, b: DataType, c: Type) =>
+        resolveUnionType(a, b, a.getType == Type.UNION).getType == c
+
+      catalystType match {
+        case ByteType | ShortType | IntegerType =>
+          isTypeCompatible(avroSchema, catalystType, Type.INT)
+        case BooleanType => isTypeCompatible(avroSchema, catalystType, Type.BOOLEAN)
+        case BinaryType => isTypeCompatible(avroSchema, catalystType, Type.BYTES)
+        case LongType | TimestampType => isTypeCompatible(avroSchema, catalystType, Type.LONG)
+        case FloatType => isTypeCompatible(avroSchema, catalystType, Type.FLOAT)
+        case DoubleType => isTypeCompatible(avroSchema, catalystType, Type.DOUBLE)
+        case d: DecimalType =>
+          // newConverter always returns a string representation for DecimalType, so we honor
+          // that here, since we don't yet support Avro's logical types
+          isTypeCompatible(avroSchema, catalystType, Type.STRING)
+        case StringType => isTypeCompatible(avroSchema, catalystType, Type.STRING) ||
+          isTypeCompatible(avroSchema, catalystType, Type.ENUM)
+        case DateType => isTypeCompatible(avroSchema, catalystType, Type.INT) ||
+          isTypeCompatible(avroSchema, catalystType, Type.LONG)
+        case ArrayType(et, containsNull) =>
+          isTypeCompatible(avroSchema, catalystType, Type.ARRAY) &&
+            typeMatchesSchema(et,
+              resolveUnionType(avroSchema, catalystType, avroSchema.getType == Type.UNION)
+                .getElementType)
+        case MapType(kt, vt, valueContainsNull) =>
+          isTypeCompatible(avroSchema, catalystType, Type.MAP) &&
+            typeMatchesSchema(vt,
+              resolveUnionType(avroSchema, catalystType, avroSchema.getType == Type.UNION)
+                .getValueType)
+      }
+    }
+
--- End diff --

remove the extra line
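For readers following the discussion, here is a tiny standalone sketch (not part of the PR; the object name ResolveUnionSketch is just illustrative, and only the stock Avro SchemaBuilder API is assumed) of the union-filtering idea the new resolveUnionType describes: given a nullable union such as ["null", "long"] and a Catalyst LongType, only the LONG branch should survive, and it must be unambiguous.

    import org.apache.avro.{Schema, SchemaBuilder}
    import scala.collection.JavaConverters._

    object ResolveUnionSketch {
      def main(args: Array[String]): Unit = {
        // Avro's usual encoding of a nullable long: ["null", "long"]
        val nullableLong: Schema =
          SchemaBuilder.unionOf().nullType().and().longType().endUnion()

        // Mirror the filtering step from the diff for a Catalyst LongType:
        // keep only the union branches whose Avro type is LONG.
        val branches = nullableLong.getTypes.asScala
        val resolved = branches.filter(_.getType == Schema.Type.LONG)

        assert(resolved.length == 1)    // exactly one match, so resolution is unambiguous
        println(resolved.head.getType)  // LONG
      }
    }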