panbingkun commented on code in PR #47888:
URL: https://github.com/apache/spark/pull/47888#discussion_r1735478092


##########
connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala:
##########
@@ -371,4 +372,188 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
           stop = 138)))
     }
   }
+
+  private def serialize(record: GenericRecord, avroSchema: String): Array[Byte] = {
+    val schema = new Schema.Parser().parse(avroSchema)
+    val datumWriter = new GenericDatumWriter[GenericRecord](schema)
+    var outputStream: ByteArrayOutputStream = null
+    var bytes: Array[Byte] = null
+    try {
+      outputStream = new ByteArrayOutputStream()
+      val encoder = EncoderFactory.get.binaryEncoder(outputStream, null)
+      datumWriter.write(record, encoder)
+      encoder.flush()
+      bytes = outputStream.toByteArray
+    } finally {
+      if (outputStream != null) {
+        outputStream.close()
+      }
+    }
+    bytes
+  }
+
+  private def deserialize(bytes: Array[Byte], avroSchema: String): GenericRecord = {
+    val schema = new Schema.Parser().parse(avroSchema)
+    val datumReader = new GenericDatumReader[GenericRecord](schema)
+    var inputStream: SeekableByteArrayInput = null
+    var record: GenericRecord = null
+    try {
+      inputStream = new SeekableByteArrayInput(bytes)
+      val decoder = DecoderFactory.get.binaryDecoder(inputStream, null)
+      record = datumReader.read(null, decoder)
+    } finally {
+      if (inputStream != null) {
+        inputStream.close()
+      }
+    }
+    record
+  }
+
+  test("GenericRecord serialize/deserialize") {
+    val avroSchema =
+      """
+        |{
+        |  "type": "record",
+        |  "name": "person",
+        |  "fields": [
+        |    {"name": "name", "type": "string"},
+        |    {"name": "age", "type": "int"},
+        |    {"name": "country", "type": "string"}
+        |  ]
+        |}
+        |""".stripMargin
+    val schema = new Schema.Parser().parse(avroSchema)
+    val person = new GenericRecordBuilder(schema)
+      .set("name", "spark")
+      .set("age", 18)
+      .set("country", "usa")
+      .build()
+    val bytes = serialize(person, avroSchema)
+    val readback = deserialize(bytes, avroSchema)
+    assert(person.get("name").toString === readback.get("name").toString)
+    assert(person.get("age") === readback.get("age"))
+    assert(person.get("country").toString === readback.get("country").toString)
+  }
+
+  test("use `from_avro` to read GenericRecord(stored in `array[Byte]` 
datatype)") {
+    // write: Person (avro `GenericRecord`) -> binary (serialize) -> dataframe
+    // read: dataframe -> from_avro -> struct -> Person (avro `GenericRecord`)
+    val avroSchema =
+      """
+        |{
+        |  "type": "record",
+        |  "name": "person",
+        |  "fields": [
+        |    {"name": "name", "type": "string"},
+        |    {"name": "age", "type": "int"},
+        |    {"name": "country", "type": "string"}
+        |  ]
+        |}
+        |""".stripMargin
+    val testTable = "test_avro"
+    withTable(testTable) {
+      val schema = new Schema.Parser().parse(avroSchema)
+      val person1 = new GenericRecordBuilder(schema)
+        .set("name", "sparkA")
+        .set("age", 18)
+        .set("country", "usa")
+        .build()
+      val person2 = new GenericRecordBuilder(schema)
+        .set("name", "sparkB")
+        .set("age", 19)
+        .set("country", "usb")
+        .build()
+      Seq(person1, person2)
+        .map(p => serialize(p, avroSchema))
+        .toDF("data")
+        .repartition(1)
+        .writeTo(testTable)
+        .create()
+
+      val expectedSchema = new StructType().add("data", BinaryType)
+      assert(spark.table(testTable).schema === expectedSchema)
+
+      val avroDF = sql(s"SELECT from_avro(data, '$avroSchema', map()) FROM $testTable")
+      val readbacks = avroDF
+        .collect()
+        .map(row =>
+          new GenericRecordBuilder(schema)
+            .set("name", row.getStruct(0).getString(0))
+            .set("age", row.getStruct(0).getInt(1))
+            .set("country", row.getStruct(0).getString(2))
+            .build())
+
+      val readbackPerson1 = readbacks.head
+      assert(readbackPerson1.get(0).toString === person1.get(0))
+      assert(readbackPerson1.get(1).asInstanceOf[Int] === person1.get(1).asInstanceOf[Int])
+      assert(readbackPerson1.get(2).toString === person1.get(2))
+
+      val readbackPerson2 = readbacks(1)
+      assert(readbackPerson2.get(0).toString === person2.get(0))
+      assert(readbackPerson2.get(1).asInstanceOf[Int] === person2.get(1).asInstanceOf[Int])
+      assert(readbackPerson2.get(2).toString === person2.get(2))
+    }
+  }
+
+  test("use `to_avro` to read GenericRecord(stored in `struct` datatype)") {

Review Comment:
   I have refactored this code and it reads more clearly now. Please review it when you have a moment, thanks!
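   
   As a non-blocking follow-up: the try/finally plumbing in these helpers could be trimmed with `scala.util.Using` from the Scala 2.13 standard library, which closes the stream on both the success and failure paths. A minimal sketch of the `serialize` side only, under that assumption (the `deserialize` side would look the same):
   
   ```scala
   import java.io.ByteArrayOutputStream
   
   import scala.util.Using
   
   import org.apache.avro.Schema
   import org.apache.avro.generic.{GenericDatumWriter, GenericRecord}
   import org.apache.avro.io.EncoderFactory
   
   // Sketch only: Using.resource closes outputStream even if write/flush throws.
   private def serialize(record: GenericRecord, avroSchema: String): Array[Byte] = {
     val schema = new Schema.Parser().parse(avroSchema)
     val datumWriter = new GenericDatumWriter[GenericRecord](schema)
     Using.resource(new ByteArrayOutputStream()) { outputStream =>
       val encoder = EncoderFactory.get.binaryEncoder(outputStream, null)
       datumWriter.write(record, encoder)
       encoder.flush()
       outputStream.toByteArray
     }
   }
   ```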
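   
   Purely as an illustration, the decode step in the `from_avro` test could also go through the DataFrame API instead of SQL text, via `org.apache.spark.sql.avro.functions.from_avro`. A sketch, assuming the same `testTable` and `avroSchema` values as in the test above:
   
   ```scala
   import org.apache.spark.sql.avro.functions.from_avro
   import org.apache.spark.sql.functions.col
   
   // Decode the binary column with the same writer schema used by serialize().
   val avroDF = spark.table(testTable)
     .select(from_avro(col("data"), avroSchema).as("person"))
   avroDF.select("person.name", "person.age", "person.country").show()
   ```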



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

