Repository: spark
Updated Branches:
  refs/heads/master bdd27961c -> 0cea9e3cd


[SPARK-24855][SQL][EXTERNAL] Built-in AVRO support should support specified schema on write

## What changes were proposed in this pull request?

Allows the `avroSchema` option to be specified on write, so a user can supply an 
exact schema where one is required. A typical use case is reading an Avro 
dataset, making a small adjustment to one or more columns, and writing it back 
out with the same schema. The schema derived implicitly from the SQL struct type 
is functionally similar for the most part, but not necessarily compatible with 
the original.

Also allows the Avro `fixed` field type to be used for records written with a 
specified `avroSchema`, as sketched below.
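
For illustration, a minimal sketch of the intended round trip (the paths, the 
`value` column, and the `.avsc` file are placeholders, not part of this patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.upper

val spark = SparkSession.builder().appName("avro-roundtrip").getOrCreate()
import spark.implicits._

// The Avro schema JSON the dataset was originally written with, loaded from
// an .avsc file purely for illustration.
val avroSchema: String = scala.io.Source.fromFile("/tmp/schema.avsc").mkString

spark.read.format("avro").load("/tmp/in")
  .withColumn("value", upper($"value"))   // some small adjustment
  .write.format("avro")
  .option("avroSchema", avroSchema)       // write using the original schema
  .save("/tmp/out")
```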

## How was this patch tested?

Unit tests in AvroSuite are extended to test this with enum and fixed types.


Closes #21847 from lindblombr/specify_schema_on_write.

Lead-authored-by: Brian Lindblom <blindb...@apple.com>
Co-authored-by: DB Tsai <d_t...@apple.com>
Signed-off-by: DB Tsai <d_t...@apple.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0cea9e3c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0cea9e3c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0cea9e3c

Branch: refs/heads/master
Commit: 0cea9e3cd0a92799bdcc0f9bc2cf96259c343a30
Parents: bdd2796
Author: Brian Lindblom <blindb...@apple.com>
Authored: Fri Aug 10 03:35:29 2018 +0000
Committer: DB Tsai <d_t...@apple.com>
Committed: Fri Aug 10 03:35:29 2018 +0000

----------------------------------------------------------------------
 .../apache/spark/sql/avro/AvroFileFormat.scala  |   6 +-
 .../apache/spark/sql/avro/AvroSerializer.scala  |  40 +++-
 .../org/apache/spark/sql/avro/AvroSuite.scala   | 228 ++++++++++++++++++-
 3 files changed, 257 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0cea9e3c/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
index 6ffcf37..6df23c9 100755
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
@@ -113,8 +113,10 @@ private[avro] class AvroFileFormat extends FileFormat
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
     val parsedOptions = new AvroOptions(options, spark.sessionState.newHadoopConf())
-    val outputAvroSchema = SchemaConverters.toAvroType(dataSchema, nullable = false,
-      parsedOptions.recordName, parsedOptions.recordNamespace, parsedOptions.outputTimestampType)
+    val outputAvroSchema: Schema = parsedOptions.schema
+      .map(new Schema.Parser().parse)
+      .getOrElse(SchemaConverters.toAvroType(dataSchema, nullable = false,
+        parsedOptions.recordName, parsedOptions.recordNamespace))
 
     AvroJob.setOutputKeySchema(job, outputAvroSchema)
 
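
Stripped of the surrounding plumbing, the new resolution rule is: prefer the 
user-supplied schema, otherwise convert the Catalyst schema. A minimal 
standalone sketch (`resolveOutputSchema` is a hypothetical name, not code from 
this patch):

```scala
import org.apache.avro.Schema

// Parse the user-supplied Avro schema JSON when present; otherwise fall back
// to the schema converted from the DataFrame's Catalyst StructType.
def resolveOutputSchema(userSchemaJson: Option[String], converted: => Schema): Schema =
  userSchemaJson.map(json => new Schema.Parser().parse(json)).getOrElse(converted)
```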

http://git-wip-us.apache.org/repos/asf/spark/blob/0cea9e3c/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
index 9885826..216c52a 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -23,8 +23,8 @@ import scala.collection.JavaConverters._
 
 import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
 import org.apache.avro.Schema
-import org.apache.avro.Schema.Type.NULL
-import org.apache.avro.generic.GenericData.Record
+import org.apache.avro.Schema.Type
+import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
 import org.apache.avro.util.Utf8
 
 import org.apache.spark.sql.catalyst.InternalRow
@@ -87,10 +87,36 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
         (getter, ordinal) => getter.getDouble(ordinal)
       case d: DecimalType =>
         (getter, ordinal) => getter.getDecimal(ordinal, d.precision, d.scale).toString
-      case StringType =>
-        (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)
-      case BinaryType =>
-        (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
+      case StringType => avroType.getType match {
+        case Type.ENUM =>
+          import scala.collection.JavaConverters._
+          val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet
+          (getter, ordinal) =>
+            val data = getter.getUTF8String(ordinal).toString
+            if (!enumSymbols.contains(data)) {
+              throw new IncompatibleSchemaException(
+                "Cannot write \"" + data + "\" since it's not defined in enum \"" +
+                  enumSymbols.mkString("\", \"") + "\"")
+            }
+            new EnumSymbol(avroType, data)
+        case _ =>
+          (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)
+      }
+      case BinaryType => avroType.getType match {
+        case Type.FIXED =>
+          val size = avroType.getFixedSize()
+          (getter, ordinal) =>
+            val data: Array[Byte] = getter.getBinary(ordinal)
+            if (data.length != size) {
+              throw new IncompatibleSchemaException(
+                s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " +
+                  "binary data into FIXED Type with size of " +
+                  s"$size ${if (size > 1) "bytes" else "byte"}")
+            }
+            new Fixed(avroType, data)
+        case _ =>
+          (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
+      }
       case DateType =>
         (getter, ordinal) => getter.getInt(ordinal)
       case TimestampType => avroType.getLogicalType match {
@@ -182,7 +208,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
       // avro uses union to represent nullable type.
       val fields = avroType.getTypes.asScala
       assert(fields.length == 2)
-      val actualType = fields.filter(_.getType != NULL)
+      val actualType = fields.filter(_.getType != Type.NULL)
       assert(actualType.length == 1)
       actualType.head
     } else {
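
For reference, a standalone sketch of the Avro generic API the new branches 
target; the schemas here are invented for illustration:

```scala
import org.apache.avro.Schema
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}

// Enum values must be wrapped in an EnumSymbol tied to the enum schema; a
// string outside the declared symbols is invalid, hence the membership check.
val enumSchema = new Schema.Parser().parse(
  """{"type": "enum", "name": "Suit", "symbols": ["SPADES", "HEARTS"]}""")
val spades = new EnumSymbol(enumSchema, "SPADES")

// Fixed values must match the declared size exactly, hence the length check.
val fixedSchema = new Schema.Parser().parse(
  """{"type": "fixed", "name": "fixed2", "size": 2}""")
val twoBytes = new Fixed(fixedSchema, Array[Byte](0xC0.toByte, 0xA8.toByte))
```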

http://git-wip-us.apache.org/repos/asf/spark/blob/0cea9e3c/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 47995bb..ada9980 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -32,6 +32,7 @@ import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWri
 import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.commons.io.FileUtils
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.DataSource
@@ -100,6 +101,25 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     checkAnswer(newEntries, originalEntries)
   }
 
+  def checkAvroSchemaEquals(avroSchema: String, expectedAvroSchema: String): Unit = {
+    assert(new Schema.Parser().parse(avroSchema) ==
+      new Schema.Parser().parse(expectedAvroSchema))
+  }
+
+  def getAvroSchemaStringFromFiles(filePath: String): String = {
+    new DataFileReader({
+      val file = new File(filePath)
+      if (file.isFile) {
+        file
+      } else {
+        file.listFiles()
+          .filter(_.isFile)
+          .filter(_.getName.endsWith("avro"))
+          .head
+      }
+    }, new GenericDatumReader[Any]()).getSchema.toString(false)
+  }
+
   test("resolve avro data source") {
     Seq("avro", "com.databricks.spark.avro").foreach { provider =>
       assert(DataSource.lookupDataSource(provider, spark.sessionState.conf) ===
@@ -471,7 +491,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       }
     """
     val df = spark.read.format("avro").option("avroSchema", avroSchema).load(timestampAvro)
-
     checkAnswer(df, expected)
   }
 
@@ -773,6 +792,205 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     assert(result === Row("foo"))
   }
 
+  test("support user provided avro schema for writing nullable enum type") {
+    withTempPath { tempDir =>
+      val avroSchema =
+        """
+          |{
+          |  "type" : "record",
+          |  "name" : "test_schema",
+          |  "fields" : [{
+          |    "name": "enum",
+          |    "type": [{ "type": "enum",
+          |              "name": "Suit",
+          |              "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]
+          |            }, "null"]
+          |  }]
+          |}
+        """.stripMargin
+
+      val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+        Row("SPADES"), Row(null), Row("HEARTS"), Row("DIAMONDS"),
+        Row(null), Row("CLUBS"), Row("HEARTS"), Row("SPADES"))),
+        StructType(Seq(StructField("Suit", StringType, true))))
+
+      val tempSaveDir = s"$tempDir/save/"
+
+      df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir)
+
+      checkAnswer(df, spark.read.format("avro").load(tempSaveDir))
+      checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir))
+
+      // Writing df containing data not in the enum will throw an exception
+      val message = intercept[SparkException] {
+        spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+          Row("SPADES"), Row("NOT-IN-ENUM"), Row("HEARTS"), Row("DIAMONDS"))),
+          StructType(Seq(StructField("Suit", StringType, true))))
+          .write.format("avro").option("avroSchema", avroSchema)
+          .save(s"$tempDir/${UUID.randomUUID()}")
+      }.getCause.getMessage
+      assert(message.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " +
+        "Cannot write \"NOT-IN-ENUM\" since it's not defined in enum"))
+    }
+  }
+
+  test("support user provided avro schema for writing non-nullable enum type") {
+    withTempPath { tempDir =>
+      val avroSchema =
+        """
+          |{
+          |  "type" : "record",
+          |  "name" : "test_schema",
+          |  "fields" : [{
+          |    "name": "enum",
+          |    "type": { "type": "enum",
+          |              "name": "Suit",
+          |              "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]
+          |            }
+          |  }]
+          |}
+        """.stripMargin
+
+      val dfWithNull = spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+        Row("SPADES"), Row(null), Row("HEARTS"), Row("DIAMONDS"),
+        Row(null), Row("CLUBS"), Row("HEARTS"), Row("SPADES"))),
+        StructType(Seq(StructField("Suit", StringType, true))))
+
+      val df = spark.createDataFrame(dfWithNull.na.drop().rdd,
+        StructType(Seq(StructField("Suit", StringType, false))))
+
+      val tempSaveDir = s"$tempDir/save/"
+
+      df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir)
+
+      checkAnswer(df, spark.read.format("avro").load(tempSaveDir))
+      checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir))
+
+      // Writing df containing nulls without using avro union type will
+      // throw an exception as avro uses union type to handle null.
+      val message1 = intercept[SparkException] {
+        dfWithNull.write.format("avro")
+          .option("avroSchema", avroSchema).save(s"$tempDir/${UUID.randomUUID()}")
+      }.getCause.getMessage
+      assert(message1.contains("org.apache.avro.AvroRuntimeException: Not a union:"))
+
+      // Writing df containing data not in the enum will throw an exception
+      val message2 = intercept[SparkException] {
+        spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+          Row("SPADES"), Row("NOT-IN-ENUM"), Row("HEARTS"), Row("DIAMONDS"))),
+          StructType(Seq(StructField("Suit", StringType, false))))
+          .write.format("avro").option("avroSchema", avroSchema)
+          .save(s"$tempDir/${UUID.randomUUID()}")
+      }.getCause.getMessage
+      assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " +
+        "Cannot write \"NOT-IN-ENUM\" since it's not defined in enum"))
+    }
+  }
+
+  test("support user provided avro schema for writing nullable fixed type") {
+    withTempPath { tempDir =>
+      val avroSchema =
+        """
+          |{
+          |  "type" : "record",
+          |  "name" : "test_schema",
+          |  "fields" : [{
+          |    "name": "fixed2",
+          |    "type": [{ "type": "fixed",
+          |               "size": 2,
+          |               "name": "fixed2"
+          |            }, "null"]
+          |  }]
+          |}
+        """.stripMargin
+
+      val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+        Row(Array(192, 168).map(_.toByte)), Row(null))),
+        StructType(Seq(StructField("fixed2", BinaryType, true))))
+
+      val tempSaveDir = s"$tempDir/save/"
+
+      df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir)
+
+      checkAnswer(df, spark.read.format("avro").load(tempSaveDir))
+      checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir))
+
+      // Writing df containing binary data that doesn't fit FIXED size will throw an exception
+      val message1 = intercept[SparkException] {
+        spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+          Row(Array(192, 168, 1).map(_.toByte)))),
+          StructType(Seq(StructField("fixed2", BinaryType, true))))
+          .write.format("avro").option("avroSchema", avroSchema)
+          .save(s"$tempDir/${UUID.randomUUID()}")
+      }.getCause.getMessage
+      assert(message1.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " +
+        "Cannot write 3 bytes of binary data into FIXED Type with size of 2 bytes"))
+
+      // Writing df containing binary data that doesn't fit FIXED size will throw an exception
+      val message2 = intercept[SparkException] {
+        spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+          Row(Array(192).map(_.toByte)))),
+          StructType(Seq(StructField("fixed2", BinaryType, true))))
+          .write.format("avro").option("avroSchema", avroSchema)
+          .save(s"$tempDir/${UUID.randomUUID()}")
+      }.getCause.getMessage
+      assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " +
+        "Cannot write 1 byte of binary data into FIXED Type with size of 2 bytes"))
+    }
+  }
+
+  test("support user provided avro schema for writing non-nullable fixed type") {
+    withTempPath { tempDir =>
+      val avroSchema =
+        """
+          |{
+          |  "type" : "record",
+          |  "name" : "test_schema",
+          |  "fields" : [{
+          |    "name": "fixed2",
+          |    "type": { "type": "fixed",
+          |               "size": 2,
+          |               "name": "fixed2"
+          |            }
+          |  }]
+          |}
+        """.stripMargin
+
+      val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+        Row(Array(192, 168).map(_.toByte)), Row(Array(1, 1).map(_.toByte)))),
+        StructType(Seq(StructField("fixed2", BinaryType, false))))
+
+      val tempSaveDir = s"$tempDir/save/"
+
+      df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir)
+
+      checkAnswer(df, spark.read.format("avro").load(tempSaveDir))
+      checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir))
+
+      // Writing df containing binary data that doesn't fit FIXED size will throw an exception
+      val message1 = intercept[SparkException] {
+        spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+          Row(Array(192, 168, 1).map(_.toByte)))),
+          StructType(Seq(StructField("fixed2", BinaryType, false))))
+          .write.format("avro").option("avroSchema", avroSchema)
+          .save(s"$tempDir/${UUID.randomUUID()}")
+      }.getCause.getMessage
+      assert(message1.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " +
+        "Cannot write 3 bytes of binary data into FIXED Type with size of 2 bytes"))
+
+      // Writing df containing binary data that doesn't fit FIXED size will throw an exception
+      val message2 = intercept[SparkException] {
+        spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+          Row(Array(192).map(_.toByte)))),
+          StructType(Seq(StructField("fixed2", BinaryType, false))))
+          .write.format("avro").option("avroSchema", avroSchema)
+          .save(s"$tempDir/${UUID.randomUUID()}")
+      }.getCause.getMessage
+      assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " +
+        "Cannot write 1 byte of binary data into FIXED Type with size of 2 bytes"))
+    }
+  }
+
   test("reading from invalid path throws exception") {
 
     // Directory given has no avro files
@@ -936,13 +1154,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     withTempPath { dir =>
       val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1")))))
       writeDf.write.format("avro").save(dir.toString)
-      val file = new File(dir.toString)
-        .listFiles()
-        .filter(_.isFile)
-        .filter(_.getName.endsWith("avro"))
-        .head
-      val reader = new DataFileReader(file, new GenericDatumReader[Any]())
-      val schema = reader.getSchema.toString()
+      val schema = getAvroSchemaStringFromFiles(dir.toString)
       assert(schema.contains("\"namespace\":\"topLevelRecord\""))
       assert(schema.contains("\"namespace\":\"topLevelRecord.data\""))
       assert(schema.contains("\"namespace\":\"topLevelRecord.data.data\""))
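
Outside the suite, the same check can be done by hand; a minimal sketch, with 
"/tmp/out" as a placeholder output directory:

```scala
import java.io.File
import org.apache.avro.file.DataFileReader
import org.apache.avro.generic.GenericDatumReader

// Open the first .avro part file Spark wrote into the directory and inspect
// the writer schema embedded in the Avro container file.
val partFile = new File("/tmp/out").listFiles()
  .filter(f => f.isFile && f.getName.endsWith("avro"))
  .head
val reader = new DataFileReader(partFile, new GenericDatumReader[Any]())
println(reader.getSchema.toString(true))  // pretty-printed schema JSON
reader.close()
```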

