This is an automated email from the ASF dual-hosted git repository. gengliang pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 0ba6ba9a382 [SPARK-25050][SQL] Avro: writing complex unions 0ba6ba9a382 is described below commit 0ba6ba9a3829cf63f8917367ea3c066e422ad04f Author: Steven Aerts <steven.ae...@gmail.com> AuthorDate: Thu Feb 23 13:23:30 2023 -0800 [SPARK-25050][SQL] Avro: writing complex unions ### What changes were proposed in this pull request? Spark was already able to read complex unions, but not to write them. Now it is possible to write them as well. If you have a schema with a complex union, the following code now works: ```scala spark .read.format("avro").option("avroSchema", avroSchema).load(path) .write.format("avro").option("avroSchema", avroSchema).save("/tmp/b") ``` Before this patch, it would throw `Unsupported Avro UNION type` when writing. This adds the capability to write complex unions, alongside reading them. Complex unions map to struct types whose field names are member0, member1, etc. This is consistent with the behavior in SchemaConverters when reading them, and when converting between Avro and Parquet. ### Why are the changes needed? Fixes SPARK-25050 and aligns read and write compatibility. ### Does this PR introduce _any_ user-facing change? The behaviour is improved, of course; as far as I could see, this does not impact any customer-facing APIs or documentation. ### How was this patch tested? - Added extra unit tests. - Updated existing unit tests for the improved behaviour. - Validated manually against an internal corpus of Avro files that they can now be read and written without problems, which was not the case before this patch. Closes #36506 from steven-aerts/spark-25050. 
Authored-by: Steven Aerts <steven.ae...@gmail.com> Signed-off-by: Gengliang Wang <gengli...@apache.org> --- .../apache/spark/sql/avro/AvroDeserializer.scala | 5 +- .../org/apache/spark/sql/avro/AvroSerializer.scala | 90 +++++++++++--- .../org/apache/spark/sql/avro/AvroUtils.scala | 5 + .../apache/spark/sql/avro/SchemaConverters.scala | 2 +- .../apache/spark/sql/avro/AvroFunctionsSuite.scala | 30 +++-- .../org/apache/spark/sql/avro/AvroSuite.scala | 135 ++++++++++++++++----- 6 files changed, 207 insertions(+), 60 deletions(-) diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 1192856ae77..aac979cddb2 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -29,7 +29,7 @@ import org.apache.avro.Schema.Type._ import org.apache.avro.generic._ import org.apache.avro.util.Utf8 -import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField} +import org.apache.spark.sql.avro.AvroUtils.{nonNullUnionBranches, toFieldStr, AvroMatchedField} import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters} import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} @@ -289,8 +289,7 @@ private[sql] class AvroDeserializer( updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) case (UNION, _) => - val allTypes = avroType.getTypes.asScala - val nonNullTypes = allTypes.filter(_.getType != NULL) + val nonNullTypes = nonNullUnionBranches(avroType) val nonNullAvroType = Schema.createUnion(nonNullTypes.asJava) if (nonNullTypes.nonEmpty) { if (nonNullTypes.length == 1) { diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala 
b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index 4a82df6ba0d..c95d731f0de 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -32,7 +32,7 @@ import org.apache.avro.generic.GenericData.Record import org.apache.avro.util.Utf8 import org.apache.spark.internal.Logging -import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField} +import org.apache.spark.sql.avro.AvroUtils.{nonNullUnionBranches, toFieldStr, AvroMatchedField} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -218,6 +218,17 @@ private[sql] class AvroSerializer( val numFields = st.length (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) + case (st: StructType, UNION) => + val unionConvertor = newComplexUnionConverter(st, avroType, catalystPath, avroPath) + val numFields = st.length + (getter, ordinal) => unionConvertor(getter.getStruct(ordinal, numFields)) + + case (DoubleType, UNION) if nonNullUnionTypes(avroType) == Set(FLOAT, DOUBLE) => + (getter, ordinal) => getter.getDouble(ordinal) + + case (LongType, UNION) if nonNullUnionTypes(avroType) == Set(INT, LONG) => + (getter, ordinal) => getter.getLong(ordinal) + case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType => val valueConverter = newConverter( vt, resolveNullableType(avroType.getValueType, valueContainsNull), @@ -287,14 +298,59 @@ private[sql] class AvroSerializer( result } + /** + * Complex unions map to struct types where field names are member0, member1, etc. + * This is consistent with the behavior in [[SchemaConverters]] and when converting between Avro + * and Parquet. 
+ */ + private def newComplexUnionConverter( + catalystStruct: StructType, + unionType: Schema, + catalystPath: Seq[String], + avroPath: Seq[String]): InternalRow => Any = { + val nonNullTypes = nonNullUnionBranches(unionType) + val expectedFieldNames = nonNullTypes.indices.map(i => s"member$i") + val catalystFieldNames = catalystStruct.fieldNames.toSeq + if (positionalFieldMatch) { + if (expectedFieldNames.length != catalystFieldNames.length) { + throw new IncompatibleSchemaException(s"Generic Avro union at ${toFieldStr(avroPath)} " + + s"does not match the SQL schema at ${toFieldStr(catalystPath)}. It expected the " + + s"${expectedFieldNames.length} members but got ${catalystFieldNames.length}") + } + } else { + if (catalystFieldNames != expectedFieldNames) { + throw new IncompatibleSchemaException(s"Generic Avro union at ${toFieldStr(avroPath)} " + + s"does not match the SQL schema at ${toFieldStr(catalystPath)}. It expected the " + + s"following members ${expectedFieldNames.mkString("(", ", ", ")")} but got " + + s"${catalystFieldNames.mkString("(", ", ", ")")}") + } + } + + val unionBranchConverters = nonNullTypes.zip(catalystStruct).map { case (unionBranch, cf) => + newConverter(cf.dataType, unionBranch, catalystPath :+ cf.name, avroPath :+ cf.name) + }.toArray + + val numBranches = catalystStruct.length + row: InternalRow => { + var idx = 0 + var retVal: Any = null + while (idx < numBranches && retVal == null) { + if (!row.isNullAt(idx)) { + retVal = unionBranchConverters(idx).apply(row, idx) + } + idx += 1 + } + retVal + } + } + /** * Resolve a possibly nullable Avro Type. * - * An Avro type is nullable when it is a [[UNION]] of two types: one null type and another - * non-null type. This method will check the nullability of the input Avro type and return the - * non-null type within when it is nullable. Otherwise it will return the input Avro type - * unchanged. 
It will throw an [[UnsupportedAvroTypeException]] when the input Avro type is an - * unsupported nullable type. + * An Avro type is nullable when it is a [[UNION]] which contains a null type. This method will + * check the nullability of the input Avro type. + * Returns the non-null type within the union when it contains only 1 non-null type. + * Otherwise it will return the input Avro type unchanged. * * It will also log a warning message if the nullability for Avro and catalyst types are * different. @@ -306,20 +362,18 @@ private[sql] class AvroSerializer( } /** - * Check the nullability of the input Avro type and resolve it when it is nullable. The first - * return value is a [[Boolean]] indicating if the input Avro type is nullable. The second - * return value is the possibly resolved type. + * Check the nullability of the input Avro type and resolve it when it is a single nullable type. + * The first return value is a [[Boolean]] indicating if the input Avro type is nullable. + * The second return value is the possibly resolved type otherwise the input Avro type unchanged. 
*/ private def resolveAvroType(avroType: Schema): (Boolean, Schema) = { if (avroType.getType == Type.UNION) { - val fields = avroType.getTypes.asScala - val actualType = fields.filter(_.getType != Type.NULL) - if (fields.length != 2 || actualType.length != 1) { - throw new UnsupportedAvroTypeException( - s"Unsupported Avro UNION type $avroType: Only UNION of a null type and a non-null " + - "type is supported") + val containsNull = avroType.getTypes.asScala.exists(_.getType == Schema.Type.NULL) + nonNullUnionBranches(avroType) match { + case Seq() => (true, Schema.create(Type.NULL)) + case Seq(singleType) => (containsNull, singleType) + case _ => (containsNull, avroType) } - (true, actualType.head) } else { (false, avroType) } @@ -337,4 +391,8 @@ private[sql] class AvroSerializer( "schema will throw runtime exception if there is a record with null value.") } } + + private def nonNullUnionTypes(avroType: Schema): Set[Type] = { + nonNullUnionBranches(avroType).map(_.getType).toSet + } } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 45fa7450e45..e1966bd1041 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -336,4 +336,9 @@ private[sql] object AvroUtils extends Logging { private[avro] def isNullable(avroField: Schema.Field): Boolean = avroField.schema().getType == Schema.Type.UNION && avroField.schema().getTypes.asScala.exists(_.getType == Schema.Type.NULL) + + /** Collect all non null branches of a union in order. 
*/ + private[avro] def nonNullUnionBranches(avroType: Schema): Seq[Schema] = { + avroType.getTypes.asScala.filter(_.getType != Schema.Type.NULL).toSeq + } } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 375f8de3328..f616cfa9b5d 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -127,7 +127,7 @@ object SchemaConverters { case UNION => if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { // In case of a union with null, eliminate it and make a recursive call - val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) + val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema) if (remainingUnionTypes.size == 1) { toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames).copy(nullable = true) } else { diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala index 9772033ed3f..7c79162e896 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -21,7 +21,7 @@ import java.io.ByteArrayOutputStream import scala.collection.JavaConverters._ -import org.apache.avro.Schema +import org.apache.avro.{Schema, SchemaBuilder} import org.apache.avro.generic.{GenericDatumWriter, GenericRecord, GenericRecordBuilder} import org.apache.avro.io.EncoderFactory @@ -220,26 +220,36 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { functions.from_avro($"avro", avroTypeStruct)), df) } - test("to_avro with unsupported nullable Avro schema") { + test("to_avro optional union Avro schema") { val df = 
spark.range(10).select(struct($"id", $"id".cast("string").as("str")).as("struct")) - for (unsupportedAvroType <- Seq("""["null", "int", "long"]""", """["int", "long"]""")) { + for (supportedAvroType <- Seq("""["null", "int", "long"]""", """["int", "long"]""")) { val avroTypeStruct = s""" |{ | "type": "record", | "name": "struct", | "fields": [ - | {"name": "id", "type": $unsupportedAvroType}, + | {"name": "id", "type": $supportedAvroType}, | {"name": "str", "type": ["null", "string"]} | ] |} """.stripMargin - val message = intercept[SparkException] { - df.select(functions.to_avro($"struct", avroTypeStruct).as("avro")).show() - }.getCause.getMessage - assert(message.contains("Only UNION of a null type and a non-null type is supported")) + val avroStructDF = df.select(functions.to_avro($"struct", avroTypeStruct).as("avro")) + checkAnswer(avroStructDF.select( + functions.from_avro($"avro", avroTypeStruct)), df) } } + test("to_avro complex union Avro schema") { + val df = Seq((Some(1), None), (None, Some("a"))).toDF() + .select(struct(struct($"_1".as("member0"), $"_2".as("member1")).as("u")).as("struct")) + val avroTypeStruct = SchemaBuilder.record("struct").fields() + .name("u").`type`().unionOf().intType().and().stringType().endUnion().noDefault() + .endRecord().toString + val avroStructDF = df.select(functions.to_avro($"struct", avroTypeStruct).as("avro")) + checkAnswer(avroStructDF.select( + functions.from_avro($"avro", avroTypeStruct)), df) + } + test("SPARK-39775: Disable validate default values when parsing Avro schemas") { val avroTypeStruct = s""" |{ @@ -255,8 +265,8 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { val df = spark.range(5).select($"id") val structDf = df.select(struct($"id").as("struct")) - val avroStructDF = structDf.select(functions.to_avro('struct, avroTypeStruct).as("avro")) - checkAnswer(avroStructDF.select(functions.from_avro('avro, avroTypeStruct)), structDf) + val avroStructDF = 
structDf.select(functions.to_avro($"struct", avroTypeStruct).as("avro")) + checkAnswer(avroStructDF.select(functions.from_avro($"avro", avroTypeStruct)), structDf) withTempPath { dir => df.write.format("avro").save(dir.getCanonicalPath) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index debdf9b45cf..d19a11b4546 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -299,21 +299,27 @@ abstract class AvroSuite test("Complex Union Type") { withTempPath { dir => - val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) - val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) - val complexUnionType = Schema.createUnion( - List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) - val fields = Seq( - new Field("field1", complexUnionType, "doc", null.asInstanceOf[AnyVal]), - new Field("field2", complexUnionType, "doc", null.asInstanceOf[AnyVal]), - new Field("field3", complexUnionType, "doc", null.asInstanceOf[AnyVal]), - new Field("field4", complexUnionType, "doc", null.asInstanceOf[AnyVal]) - ).asJava - val schema = Schema.createRecord("name", "docs", "namespace", false) - schema.setFields(fields) + val nativeWriterPath = s"$dir.avro" + val sparkWriterPath = s"$dir/spark" + val fixedSchema = SchemaBuilder.fixed("fixed_name").size(4) + val enumSchema = SchemaBuilder.enumeration("enum_name").symbols("e1", "e2") + val complexUnionType = SchemaBuilder.unionOf() + .intType().and() + .stringType().and() + .`type`(fixedSchema).and() + .`type`(enumSchema).and() + .nullType() + .endUnion() + val schema = SchemaBuilder.record("name").fields() + .name("field1").`type`(complexUnionType).noDefault() + .name("field2").`type`(complexUnionType).noDefault() + 
.name("field3").`type`(complexUnionType).noDefault() + .name("field4").`type`(complexUnionType).noDefault() + .name("field5").`type`(complexUnionType).noDefault() + .endRecord() val datumWriter = new GenericDatumWriter[GenericRecord](schema) val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) - dataFileWriter.create(schema, new File(s"$dir.avro")) + dataFileWriter.create(schema, new File(nativeWriterPath)) val avroRec = new GenericData.Record(schema) val field1 = 1234 val field2 = "Hope that was not load bearing" @@ -323,15 +329,32 @@ abstract class AvroSuite avroRec.put("field2", field2) avroRec.put("field3", new Fixed(fixedSchema, field3)) avroRec.put("field4", new EnumSymbol(enumSchema, field4)) + avroRec.put("field5", null) dataFileWriter.append(avroRec) dataFileWriter.flush() dataFileWriter.close() - val df = spark.sqlContext.read.format("avro").load(s"$dir.avro") - assertResult(field1)(df.selectExpr("field1.member0").first().get(0)) - assertResult(field2)(df.selectExpr("field2.member1").first().get(0)) - assertResult(field3)(df.selectExpr("field3.member2").first().get(0)) - assertResult(field4)(df.selectExpr("field4.member3").first().get(0)) + val df = spark.sqlContext.read.format("avro").load(nativeWriterPath) + assertResult(Row(field1, null, null, null))(df.selectExpr("field1.*").first()) + assertResult(Row(null, field2, null, null))(df.selectExpr("field2.*").first()) + assertResult(Row(null, null, field3, null))(df.selectExpr("field3.*").first()) + assertResult(Row(null, null, null, field4))(df.selectExpr("field4.*").first()) + assertResult(Row(null, null, null, null))(df.selectExpr("field5.*").first()) + + df.write.format("avro").option("avroSchema", schema.toString).save(sparkWriterPath) + + val df2 = spark.sqlContext.read.format("avro").load(nativeWriterPath) + assertResult(Row(field1, null, null, null))(df2.selectExpr("field1.*").first()) + assertResult(Row(null, field2, null, null))(df2.selectExpr("field2.*").first()) + 
assertResult(Row(null, null, field3, null))(df2.selectExpr("field3.*").first()) + assertResult(Row(null, null, null, field4))(df2.selectExpr("field4.*").first()) + assertResult(Row(null, null, null, null))(df2.selectExpr("field5.*").first()) + + val reader = openDatumReader(new File(sparkWriterPath)) + assert(reader.hasNext) + assertResult(avroRec)(reader.next()) + assert(!reader.hasNext) + reader.close() } } @@ -1143,32 +1166,81 @@ abstract class AvroSuite } } - test("unsupported nullable avro type") { + test("int/long double/float conversion") { val catalystSchema = StructType(Seq( - StructField("Age", IntegerType, nullable = false), - StructField("Name", StringType, nullable = false))) + StructField("Age", LongType), + StructField("Length", DoubleType), + StructField("Name", StringType))) - for (unsupportedAvroType <- Seq("""["null", "int", "long"]""", """["int", "long"]""")) { + for (optionalNull <- Seq(""""null",""", "")) { val avroSchema = s""" |{ | "type" : "record", | "name" : "test_schema", | "fields" : [ - | {"name": "Age", "type": $unsupportedAvroType}, + | {"name": "Age", "type": [$optionalNull "int", "long"]}, + | {"name": "Length", "type": [$optionalNull "float", "double"]}, | {"name": "Name", "type": ["null", "string"]} | ] |} """.stripMargin val df = spark.createDataFrame( - spark.sparkContext.parallelize(Seq(Row(2, "Aurora"))), catalystSchema) + spark.sparkContext.parallelize(Seq(Row(2L, 1.8D, "Aurora"), Row(1L, 0.9D, null))), + catalystSchema) + + withTempPath { tempDir => + df.write.format("avro").option("avroSchema", avroSchema).save(tempDir.getPath) + checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .load(tempDir.getPath), + df) + } + } + } + + test("non-matching complex union types") { + val catalystSchema = new StructType().add("Union", new StructType() + .add("member0", IntegerType) + .add("member1", new StructType().add("f1", StringType, nullable = false)) + ) + + val df = spark.createDataFrame( + 
spark.sparkContext.parallelize(Seq(Row(Row(1, null)))), catalystSchema) + + val recordS = SchemaBuilder.record("r").fields().requiredString("f1").endRecord() + val intS = Schema.create(Schema.Type.INT) + val nullS = Schema.create(Schema.Type.NULL) + for ((unionTypes, compatible) <- Seq( + (Seq(nullS, intS, recordS), true), + (Seq(intS, nullS, recordS), true), + (Seq(intS, recordS, nullS), true), + (Seq(intS, recordS), true), + (Seq(nullS, recordS, intS), false), + (Seq(nullS, recordS), false), + (Seq(nullS, SchemaBuilder.record("r").fields().requiredString("f2").endRecord()), false) + )) { + val avroSchema = SchemaBuilder.record("test_schema").fields() + .name("union").`type`(Schema.createUnion(unionTypes: _*)).noDefault() + .endRecord().toString() withTempPath { tempDir => - val message = intercept[SparkException] { + if (!compatible) { + intercept[SparkException] { + df.write.format("avro").option("avroSchema", avroSchema).save(tempDir.getPath) + } + } else { df.write.format("avro").option("avroSchema", avroSchema).save(tempDir.getPath) - }.getMessage - assert(message.contains("Only UNION of a null type and a non-null type is supported")) + checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .load(tempDir.getPath), + df) + } } } } @@ -2104,12 +2176,15 @@ abstract class AvroSuite } private def checkMetaData(path: java.io.File, key: String, expectedValue: String): Unit = { + val value = openDatumReader(path).asInstanceOf[DataFileReader[_]].getMetaString(key) + assert(value === expectedValue) + } + + private def openDatumReader(path: File): org.apache.avro.file.FileReader[GenericRecord] = { val avroFiles = path.listFiles() .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) assert(avroFiles.length === 1) - val reader = DataFileReader.openReader(avroFiles(0), new GenericDatumReader[GenericRecord]()) - val value = reader.asInstanceOf[DataFileReader[_]].getMetaString(key) - assert(value === expectedValue) + 
DataFileReader.openReader(avroFiles(0), new GenericDatumReader[GenericRecord]()) } test("SPARK-31327: Write Spark version into Avro file metadata") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org