Github user dbtsai commented on a diff in the pull request:

https://github.com/apache/spark/pull/21847#discussion_r209117054

--- Diff: external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala ---
@@ -725,6 +744,205 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       assert(result === Row("foo"))
     }
+  test("support user provided avro schema for writing nullable enum type") {
+    withTempPath { tempDir =>
+      val avroSchema =
+        """
+          |{
+          |  "type" : "record",
+          |  "name" : "test_schema",
+          |  "fields" : [{
+          |    "name": "enum",
+          |    "type": [{ "type": "enum",
+          |              "name": "Suit",
+          |              "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]
+          |            }, "null"]
+          |  }]
+          |}
+        """.stripMargin
+
+      val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+        Row("SPADES"), Row(null), Row("HEARTS"), Row("DIAMONDS"),
+        Row(null), Row("CLUBS"), Row("HEARTS"), Row("SPADES"))),
+        StructType(Seq(StructField("Suit", StringType, true))))
+
+      val tempSaveDir = s"$tempDir/save/"
+
+      df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir)
+
+      checkAnswer(df, spark.read.format("avro").load(tempSaveDir))
+      checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir))
+
+      // Writing df containing data not in the enum will throw an exception
+      val message = intercept[SparkException] {
+        spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+          Row("SPADES"), Row("NOT-IN-ENUM"), Row("HEARTS"), Row("DIAMONDS"))),
+          StructType(Seq(StructField("Suit", StringType, true))))
+          .write.format("avro").option("avroSchema", avroSchema)
+          .save(s"$tempDir/${UUID.randomUUID()}")
+      }.getCause.getMessage
+      assert(message.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " +
+        "Cannot write \"NOT-IN-ENUM\" since it's not defined in enum"))
+    }
+  }
+
+  test("support user provided avro schema for writing non-nullable enum type") {
+    withTempPath { tempDir =>
+      val avroSchema =
+        """
+          |{
+          |  "type" : "record",
+          |  "name" : "test_schema",
+          |  "fields" : [{
+          |    "name": "enum",
+          |    "type": { "type": "enum",
+          |              "name": "Suit",
+          |              "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]
+          |            }
+          |  }]
+          |}
+        """.stripMargin
+
+      val dfWithNull = spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+        Row("SPADES"), Row(null), Row("HEARTS"), Row("DIAMONDS"),
+        Row(null), Row("CLUBS"), Row("HEARTS"), Row("SPADES"))),
+        StructType(Seq(StructField("Suit", StringType, true))))
+
+      val df = spark.createDataFrame(dfWithNull.na.drop().rdd,
+        StructType(Seq(StructField("Suit", StringType, false))))
+
+      val tempSaveDir = s"$tempDir/save1/"
--- End diff --

addressed
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org