This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0ba6ba9a382 [SPARK-25050][SQL] Avro: writing complex unions
0ba6ba9a382 is described below

commit 0ba6ba9a3829cf63f8917367ea3c066e422ad04f
Author: Steven Aerts <steven.ae...@gmail.com>
AuthorDate: Thu Feb 23 13:23:30 2023 -0800

    [SPARK-25050][SQL] Avro: writing complex unions
    
    ### What changes were proposed in this pull request?
    
    Spark could already read complex unions, but not write them. Now it can write them as well.
    If you have a schema with a complex union, the following code now works:
    
    ```scala
    spark
      .read.format("avro").option("avroSchema", avroSchema).load(path)
      .write.format("avro").option("avroSchema", avroSchema).save("/tmp/b")
    ```
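    
    Here `avroSchema` is an Avro schema string that contains a complex union. As an illustrative sketch (the record and field names are made up for this example), it could look like:
    
    ```scala
    // Illustrative only: a record with one complex union field "u"
    // that has two non-null branches (int and string).
    val avroSchema = """
      |{
      |  "type": "record",
      |  "name": "example",
      |  "fields": [
      |    {"name": "u", "type": ["int", "string"]}
      |  ]
      |}
    """.stripMargin
    ```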
    Before this patch, such a write would throw `Unsupported Avro UNION type`.
    
    This patch adds the capability to write complex unions, in addition to reading them.
    Complex unions map to struct types whose field names are member0, member1, etc.
    This is consistent with the behavior in SchemaConverters for reading them
    and when converting between Avro and Parquet.
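    
    As a sketch of that mapping (the path is a placeholder, and the exact nullability flags may differ), a field `u` with the union schema `["int", "string"]` surfaces as a two-member struct:
    
    ```scala
    // Hypothetical example: "/tmp/a" holds Avro records whose field "u"
    // has the complex union schema ["int", "string"].
    val df = spark.read.format("avro").load("/tmp/a")
    df.printSchema()
    // root
    //  |-- u: struct (nullable = true)
    //  |    |-- member0: integer (nullable = true)
    //  |    |-- member1: string (nullable = true)
    // Exactly one member is non-null per row; the others are null.
    df.select("u.member0", "u.member1").show()
    ```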
    
    ### Why are the changes needed?
    Fixes SPARK-25050 and aligns write support with the existing read support.
    
    ### Does this PR introduce _any_ user-facing change?
    The behaviour improves; as far as I could see, this does not impact any user-facing APIs or documentation.
    
    ### How was this patch tested?
    - Added extra unit tests.
    - Updated existing unit tests for improved behaviour.
    - Validated manually, against an internal corpus of Avro files, that they can now be read and written without problems, which was not possible before this patch.
    
    Closes #36506 from steven-aerts/spark-25050.
    
    Authored-by: Steven Aerts <steven.ae...@gmail.com>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../apache/spark/sql/avro/AvroDeserializer.scala   |   5 +-
 .../org/apache/spark/sql/avro/AvroSerializer.scala |  90 +++++++++++---
 .../org/apache/spark/sql/avro/AvroUtils.scala      |   5 +
 .../apache/spark/sql/avro/SchemaConverters.scala   |   2 +-
 .../apache/spark/sql/avro/AvroFunctionsSuite.scala |  30 +++--
 .../org/apache/spark/sql/avro/AvroSuite.scala      | 135 ++++++++++++++++-----
 6 files changed, 207 insertions(+), 60 deletions(-)

diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index 1192856ae77..aac979cddb2 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -29,7 +29,7 @@ import org.apache.avro.Schema.Type._
 import org.apache.avro.generic._
 import org.apache.avro.util.Utf8
 
-import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField}
+import org.apache.spark.sql.avro.AvroUtils.{nonNullUnionBranches, toFieldStr, AvroMatchedField}
 import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters}
 import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
@@ -289,8 +289,7 @@ private[sql] class AvroDeserializer(
           updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))
 
       case (UNION, _) =>
-        val allTypes = avroType.getTypes.asScala
-        val nonNullTypes = allTypes.filter(_.getType != NULL)
+        val nonNullTypes = nonNullUnionBranches(avroType)
         val nonNullAvroType = Schema.createUnion(nonNullTypes.asJava)
         if (nonNullTypes.nonEmpty) {
           if (nonNullTypes.length == 1) {
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
index 4a82df6ba0d..c95d731f0de 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -32,7 +32,7 @@ import org.apache.avro.generic.GenericData.Record
 import org.apache.avro.util.Utf8
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField}
+import org.apache.spark.sql.avro.AvroUtils.{nonNullUnionBranches, toFieldStr, AvroMatchedField}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -218,6 +218,17 @@ private[sql] class AvroSerializer(
         val numFields = st.length
         (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))
 
+      case (st: StructType, UNION) =>
+        val unionConvertor = newComplexUnionConverter(st, avroType, catalystPath, avroPath)
+        val numFields = st.length
+        (getter, ordinal) => unionConvertor(getter.getStruct(ordinal, numFields))
+
+      case (DoubleType, UNION) if nonNullUnionTypes(avroType) == Set(FLOAT, DOUBLE) =>
+        (getter, ordinal) => getter.getDouble(ordinal)
+
+      case (LongType, UNION) if nonNullUnionTypes(avroType) == Set(INT, LONG) =>
+        (getter, ordinal) => getter.getLong(ordinal)
+
       case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
         val valueConverter = newConverter(
           vt, resolveNullableType(avroType.getValueType, valueContainsNull),
@@ -287,14 +298,59 @@ private[sql] class AvroSerializer(
       result
   }
 
+  /**
+   * Complex unions map to struct types where field names are member0, member1, etc.
+   * This is consistent with the behavior in [[SchemaConverters]] and when converting between Avro
+   * and Parquet.
+   */
+  private def newComplexUnionConverter(
+      catalystStruct: StructType,
+      unionType: Schema,
+      catalystPath: Seq[String],
+      avroPath: Seq[String]): InternalRow => Any = {
+    val nonNullTypes = nonNullUnionBranches(unionType)
+    val expectedFieldNames = nonNullTypes.indices.map(i => s"member$i")
+    val catalystFieldNames = catalystStruct.fieldNames.toSeq
+    if (positionalFieldMatch) {
+      if (expectedFieldNames.length != catalystFieldNames.length) {
+        throw new IncompatibleSchemaException(s"Generic Avro union at ${toFieldStr(avroPath)} " +
+          s"does not match the SQL schema at ${toFieldStr(catalystPath)}.  It expected the " +
+          s"${expectedFieldNames.length} members but got ${catalystFieldNames.length}")
+      }
+    } else {
+      if (catalystFieldNames != expectedFieldNames) {
+        throw new IncompatibleSchemaException(s"Generic Avro union at ${toFieldStr(avroPath)} " +
+          s"does not match the SQL schema at ${toFieldStr(catalystPath)}.  It expected the " +
+          s"following members ${expectedFieldNames.mkString("(", ", ", ")")} but got " +
+          s"${catalystFieldNames.mkString("(", ", ", ")")}")
+      }
+    }
+
+    val unionBranchConverters = nonNullTypes.zip(catalystStruct).map { case (unionBranch, cf) =>
+      newConverter(cf.dataType, unionBranch, catalystPath :+ cf.name, avroPath :+ cf.name)
+    }.toArray
+
+    val numBranches = catalystStruct.length
+    row: InternalRow => {
+      var idx = 0
+      var retVal: Any = null
+      while (idx < numBranches && retVal == null) {
+        if (!row.isNullAt(idx)) {
+          retVal = unionBranchConverters(idx).apply(row, idx)
+        }
+        idx += 1
+      }
+      retVal
+    }
+  }
+
   /**
    * Resolve a possibly nullable Avro Type.
    *
-   * An Avro type is nullable when it is a [[UNION]] of two types: one null type and another
-   * non-null type. This method will check the nullability of the input Avro type and return the
-   * non-null type within when it is nullable. Otherwise it will return the input Avro type
-   * unchanged. It will throw an [[UnsupportedAvroTypeException]] when the input Avro type is an
-   * unsupported nullable type.
+   * An Avro type is nullable when it is a [[UNION]] which contains a null type.  This method will
+   * check the nullability of the input Avro type.
+   * Returns the non-null type within the union when it contains only 1 non-null type.
+   * Otherwise it will return the input Avro type unchanged.
    *
    * It will also log a warning message if the nullability for Avro and catalyst types are
    * different.
@@ -306,20 +362,18 @@ private[sql] class AvroSerializer(
   }
 
   /**
-   * Check the nullability of the input Avro type and resolve it when it is nullable. The first
-   * return value is a [[Boolean]] indicating if the input Avro type is nullable. The second
-   * return value is the possibly resolved type.
+   * Check the nullability of the input Avro type and resolve it when it is a single nullable type.
+   * The first return value is a [[Boolean]] indicating if the input Avro type is nullable.
+   * The second return value is the possibly resolved type otherwise the input Avro type unchanged.
    */
   private def resolveAvroType(avroType: Schema): (Boolean, Schema) = {
     if (avroType.getType == Type.UNION) {
-      val fields = avroType.getTypes.asScala
-      val actualType = fields.filter(_.getType != Type.NULL)
-      if (fields.length != 2 || actualType.length != 1) {
-        throw new UnsupportedAvroTypeException(
-          s"Unsupported Avro UNION type $avroType: Only UNION of a null type 
and a non-null " +
-            "type is supported")
+      val containsNull = avroType.getTypes.asScala.exists(_.getType == 
Schema.Type.NULL)
+      nonNullUnionBranches(avroType) match {
+        case Seq() => (true, Schema.create(Type.NULL))
+        case Seq(singleType) => (containsNull, singleType)
+        case _ => (containsNull, avroType)
       }
-      (true, actualType.head)
     } else {
       (false, avroType)
     }
@@ -337,4 +391,8 @@ private[sql] class AvroSerializer(
         "schema will throw runtime exception if there is a record with null value.")
     }
   }
+
+  private def nonNullUnionTypes(avroType: Schema): Set[Type] = {
+    nonNullUnionBranches(avroType).map(_.getType).toSet
+  }
 }
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
index 45fa7450e45..e1966bd1041 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
@@ -336,4 +336,9 @@ private[sql] object AvroUtils extends Logging {
   private[avro] def isNullable(avroField: Schema.Field): Boolean =
     avroField.schema().getType == Schema.Type.UNION &&
       avroField.schema().getTypes.asScala.exists(_.getType == Schema.Type.NULL)
+
+  /** Collect all non null branches of a union in order. */
+  private[avro] def nonNullUnionBranches(avroType: Schema): Seq[Schema] = {
+    avroType.getTypes.asScala.filter(_.getType != Schema.Type.NULL).toSeq
+  }
 }
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
index 375f8de3328..f616cfa9b5d 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
@@ -127,7 +127,7 @@ object SchemaConverters {
       case UNION =>
         if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
           // In case of a union with null, eliminate it and make a recursive call
-          val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
+          val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema)
           if (remainingUnionTypes.size == 1) {
             toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames).copy(nullable = true)
           } else {
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
index 9772033ed3f..7c79162e896 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
@@ -21,7 +21,7 @@ import java.io.ByteArrayOutputStream
 
 import scala.collection.JavaConverters._
 
-import org.apache.avro.Schema
+import org.apache.avro.{Schema, SchemaBuilder}
 import org.apache.avro.generic.{GenericDatumWriter, GenericRecord, GenericRecordBuilder}
 import org.apache.avro.io.EncoderFactory
 
@@ -220,26 +220,36 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
       functions.from_avro($"avro", avroTypeStruct)), df)
   }
 
-  test("to_avro with unsupported nullable Avro schema") {
+  test("to_avro optional union Avro schema") {
     val df = spark.range(10).select(struct($"id", $"id".cast("string").as("str")).as("struct"))
-    for (unsupportedAvroType <- Seq("""["null", "int", "long"]""", """["int", "long"]""")) {
+    for (supportedAvroType <- Seq("""["null", "int", "long"]""", """["int", "long"]""")) {
       val avroTypeStruct = s"""
         |{
         |  "type": "record",
         |  "name": "struct",
         |  "fields": [
-        |    {"name": "id", "type": $unsupportedAvroType},
+        |    {"name": "id", "type": $supportedAvroType},
         |    {"name": "str", "type": ["null", "string"]}
         |  ]
         |}
       """.stripMargin
-      val message = intercept[SparkException] {
-        df.select(functions.to_avro($"struct", avroTypeStruct).as("avro")).show()
-      }.getCause.getMessage
-      assert(message.contains("Only UNION of a null type and a non-null type is supported"))
+      val avroStructDF = df.select(functions.to_avro($"struct", avroTypeStruct).as("avro"))
+      checkAnswer(avroStructDF.select(
+        functions.from_avro($"avro", avroTypeStruct)), df)
     }
   }
 
+  test("to_avro complex union Avro schema") {
+    val df = Seq((Some(1), None), (None, Some("a"))).toDF()
+      .select(struct(struct($"_1".as("member0"), $"_2".as("member1")).as("u")).as("struct"))
+    val avroTypeStruct = SchemaBuilder.record("struct").fields()
+      .name("u").`type`().unionOf().intType().and().stringType().endUnion().noDefault()
+      .endRecord().toString
+    val avroStructDF = df.select(functions.to_avro($"struct", avroTypeStruct).as("avro"))
+    checkAnswer(avroStructDF.select(
+      functions.from_avro($"avro", avroTypeStruct)), df)
+  }
+
   test("SPARK-39775: Disable validate default values when parsing Avro 
schemas") {
     val avroTypeStruct = s"""
       |{
@@ -255,8 +265,8 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
 
     val df = spark.range(5).select($"id")
     val structDf = df.select(struct($"id").as("struct"))
-    val avroStructDF = structDf.select(functions.to_avro('struct, avroTypeStruct).as("avro"))
-    checkAnswer(avroStructDF.select(functions.from_avro('avro, avroTypeStruct)), structDf)
+    val avroStructDF = structDf.select(functions.to_avro($"struct", avroTypeStruct).as("avro"))
+    checkAnswer(avroStructDF.select(functions.from_avro($"avro", avroTypeStruct)), structDf)
 
     withTempPath { dir =>
       df.write.format("avro").save(dir.getCanonicalPath)
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index debdf9b45cf..d19a11b4546 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -299,21 +299,27 @@ abstract class AvroSuite
 
   test("Complex Union Type") {
     withTempPath { dir =>
-      val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4)
-      val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava)
-      val complexUnionType = Schema.createUnion(
-        List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava)
-      val fields = Seq(
-        new Field("field1", complexUnionType, "doc", null.asInstanceOf[AnyVal]),
-        new Field("field2", complexUnionType, "doc", null.asInstanceOf[AnyVal]),
-        new Field("field3", complexUnionType, "doc", null.asInstanceOf[AnyVal]),
-        new Field("field4", complexUnionType, "doc", null.asInstanceOf[AnyVal])
-      ).asJava
-      val schema = Schema.createRecord("name", "docs", "namespace", false)
-      schema.setFields(fields)
+      val nativeWriterPath = s"$dir.avro"
+      val sparkWriterPath = s"$dir/spark"
+      val fixedSchema = SchemaBuilder.fixed("fixed_name").size(4)
+      val enumSchema = SchemaBuilder.enumeration("enum_name").symbols("e1", "e2")
+      val complexUnionType = SchemaBuilder.unionOf()
+          .intType().and()
+          .stringType().and()
+          .`type`(fixedSchema).and()
+          .`type`(enumSchema).and()
+          .nullType()
+        .endUnion()
+      val schema = SchemaBuilder.record("name").fields()
+          .name("field1").`type`(complexUnionType).noDefault()
+          .name("field2").`type`(complexUnionType).noDefault()
+          .name("field3").`type`(complexUnionType).noDefault()
+          .name("field4").`type`(complexUnionType).noDefault()
+          .name("field5").`type`(complexUnionType).noDefault()
+        .endRecord()
       val datumWriter = new GenericDatumWriter[GenericRecord](schema)
       val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
-      dataFileWriter.create(schema, new File(s"$dir.avro"))
+      dataFileWriter.create(schema, new File(nativeWriterPath))
       val avroRec = new GenericData.Record(schema)
       val field1 = 1234
       val field2 = "Hope that was not load bearing"
@@ -323,15 +329,32 @@ abstract class AvroSuite
       avroRec.put("field2", field2)
       avroRec.put("field3", new Fixed(fixedSchema, field3))
       avroRec.put("field4", new EnumSymbol(enumSchema, field4))
+      avroRec.put("field5", null)
       dataFileWriter.append(avroRec)
       dataFileWriter.flush()
       dataFileWriter.close()
 
-      val df = spark.sqlContext.read.format("avro").load(s"$dir.avro")
-      assertResult(field1)(df.selectExpr("field1.member0").first().get(0))
-      assertResult(field2)(df.selectExpr("field2.member1").first().get(0))
-      assertResult(field3)(df.selectExpr("field3.member2").first().get(0))
-      assertResult(field4)(df.selectExpr("field4.member3").first().get(0))
+      val df = spark.sqlContext.read.format("avro").load(nativeWriterPath)
+      assertResult(Row(field1, null, null, null))(df.selectExpr("field1.*").first())
+      assertResult(Row(null, field2, null, null))(df.selectExpr("field2.*").first())
+      assertResult(Row(null, null, field3, null))(df.selectExpr("field3.*").first())
+      assertResult(Row(null, null, null, field4))(df.selectExpr("field4.*").first())
+      assertResult(Row(null, null, null, null))(df.selectExpr("field5.*").first())
+
+      df.write.format("avro").option("avroSchema", schema.toString).save(sparkWriterPath)
+
+      val df2 = spark.sqlContext.read.format("avro").load(nativeWriterPath)
+      assertResult(Row(field1, null, null, null))(df2.selectExpr("field1.*").first())
+      assertResult(Row(null, field2, null, null))(df2.selectExpr("field2.*").first())
+      assertResult(Row(null, null, field3, null))(df2.selectExpr("field3.*").first())
+      assertResult(Row(null, null, null, field4))(df2.selectExpr("field4.*").first())
+      assertResult(Row(null, null, null, null))(df2.selectExpr("field5.*").first())
+
+      val reader = openDatumReader(new File(sparkWriterPath))
+      assert(reader.hasNext)
+      assertResult(avroRec)(reader.next())
+      assert(!reader.hasNext)
+      reader.close()
     }
   }
 
@@ -1143,32 +1166,81 @@ abstract class AvroSuite
     }
   }
 
-  test("unsupported nullable avro type") {
+  test("int/long double/float conversion") {
     val catalystSchema =
       StructType(Seq(
-        StructField("Age", IntegerType, nullable = false),
-        StructField("Name", StringType, nullable = false)))
+        StructField("Age", LongType),
+        StructField("Length", DoubleType),
+        StructField("Name", StringType)))
 
-    for (unsupportedAvroType <- Seq("""["null", "int", "long"]""", """["int", "long"]""")) {
+    for (optionalNull <- Seq(""""null",""", "")) {
       val avroSchema = s"""
         |{
         |  "type" : "record",
         |  "name" : "test_schema",
         |  "fields" : [
-        |    {"name": "Age", "type": $unsupportedAvroType},
+        |    {"name": "Age", "type": [$optionalNull "int", "long"]},
+        |    {"name": "Length", "type": [$optionalNull "float", "double"]},
         |    {"name": "Name", "type": ["null", "string"]}
         |  ]
         |}
       """.stripMargin
 
       val df = spark.createDataFrame(
-        spark.sparkContext.parallelize(Seq(Row(2, "Aurora"))), catalystSchema)
+        spark.sparkContext.parallelize(Seq(Row(2L, 1.8D, "Aurora"), Row(1L, 0.9D, null))),
+        catalystSchema)
+
+      withTempPath { tempDir =>
+        df.write.format("avro").option("avroSchema", avroSchema).save(tempDir.getPath)
+        checkAnswer(
+          spark.read
+            .format("avro")
+            .option("avroSchema", avroSchema)
+            .load(tempDir.getPath),
+          df)
+      }
+    }
+  }
+
+  test("non-matching complex union types") {
+    val catalystSchema = new StructType().add("Union", new StructType()
+      .add("member0", IntegerType)
+      .add("member1", new StructType().add("f1", StringType, nullable = false))
+    )
+
+    val df = spark.createDataFrame(
+      spark.sparkContext.parallelize(Seq(Row(Row(1, null)))), catalystSchema)
+
+    val recordS = SchemaBuilder.record("r").fields().requiredString("f1").endRecord()
+    val intS = Schema.create(Schema.Type.INT)
+    val nullS = Schema.create(Schema.Type.NULL)
+    for ((unionTypes, compatible) <- Seq(
+      (Seq(nullS, intS, recordS), true),
+      (Seq(intS, nullS, recordS), true),
+      (Seq(intS, recordS, nullS), true),
+      (Seq(intS, recordS), true),
+      (Seq(nullS, recordS, intS), false),
+      (Seq(nullS, recordS), false),
+      (Seq(nullS, SchemaBuilder.record("r").fields().requiredString("f2").endRecord()), false)
+    )) {
+      val avroSchema = SchemaBuilder.record("test_schema").fields()
+        .name("union").`type`(Schema.createUnion(unionTypes: _*)).noDefault()
+        .endRecord().toString()
 
       withTempPath { tempDir =>
-        val message = intercept[SparkException] {
+        if (!compatible) {
+          intercept[SparkException] {
+            df.write.format("avro").option("avroSchema", avroSchema).save(tempDir.getPath)
+          }
+        } else {
           df.write.format("avro").option("avroSchema", avroSchema).save(tempDir.getPath)
-        }.getMessage
-        assert(message.contains("Only UNION of a null type and a non-null type is supported"))
+          checkAnswer(
+            spark.read
+              .format("avro")
+              .option("avroSchema", avroSchema)
+              .load(tempDir.getPath),
+            df)
+        }
       }
     }
   }
@@ -2104,12 +2176,15 @@ abstract class AvroSuite
   }
 
   private def checkMetaData(path: java.io.File, key: String, expectedValue: String): Unit = {
+    val value = openDatumReader(path).asInstanceOf[DataFileReader[_]].getMetaString(key)
+    assert(value === expectedValue)
+  }
+
+  private def openDatumReader(path: File): org.apache.avro.file.FileReader[GenericRecord] = {
     val avroFiles = path.listFiles()
       .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
     assert(avroFiles.length === 1)
-    val reader = DataFileReader.openReader(avroFiles(0), new GenericDatumReader[GenericRecord]())
-    val value = reader.asInstanceOf[DataFileReader[_]].getMetaString(key)
-    assert(value === expectedValue)
+    DataFileReader.openReader(avroFiles(0), new GenericDatumReader[GenericRecord]())
   }
 
   test("SPARK-31327: Write Spark version into Avro file metadata") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
