Github user bdrillard commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22878#discussion_r229329920
  
    --- Diff: 
external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala ---
    @@ -1374,4 +1377,182 @@ class AvroSuite extends QueryTest with 
SharedSQLContext with SQLTestUtils {
           |}
         """.stripMargin)
       }
    +
    +  test("generic record converts to row and back") {
    +    val nested =
    +      SchemaBuilder.record("simple_record").fields()
    +        .name("nested1").`type`("int").withDefault(0)
    +        .name("nested2").`type`("string").withDefault("string").endRecord()
    +    val schema = SchemaBuilder.record("record").fields()
    +      .name("boolean").`type`("boolean").withDefault(false)
    +      .name("int").`type`("int").withDefault(0)
    +      .name("long").`type`("long").withDefault(0L)
    +      .name("float").`type`("float").withDefault(0.0F)
    +      .name("double").`type`("double").withDefault(0.0)
    +      .name("string").`type`("string").withDefault("string")
    +      
.name("bytes").`type`("bytes").withDefault(java.nio.ByteBuffer.wrap("bytes".getBytes))
    +      .name("nested").`type`(nested).withDefault(new 
GenericRecordBuilder(nested).build)
    +      .name("enum").`type`(
    +      SchemaBuilder.enumeration("simple_enums")
    +        .symbols("SPADES", "HEARTS", "CLUBS", "DIAMONDS"))
    +      .withDefault("SPADES")
    +      .name("int_array").`type`(
    +      SchemaBuilder.array().items().`type`("int"))
    +      .withDefault(java.util.Arrays.asList(1, 2, 3))
    +      .name("string_array").`type`(
    +      SchemaBuilder.array().items().`type`("string"))
    +      .withDefault(java.util.Arrays.asList("a", "b", "c"))
    +      .name("record_array").`type`(
    +      SchemaBuilder.array.items.`type`(nested))
    +      .withDefault(java.util.Arrays.asList(
    +        new GenericRecordBuilder(nested).build,
    +        new GenericRecordBuilder(nested).build))
    +      .name("enum_array").`type`(
    +      SchemaBuilder.array.items.`type`(
    +        SchemaBuilder.enumeration("simple_enums")
    +          .symbols("SPADES", "HEARTS", "CLUBS", "DIAMONDS")))
    +      .withDefault(java.util.Arrays.asList("SPADES", "HEARTS", "SPADES"))
    +      .name("fixed_array").`type`(
    +      SchemaBuilder.array.items().`type`(
    +        SchemaBuilder.fixed("simple_fixed").size(3)))
    +      .withDefault(java.util.Arrays.asList("foo", "bar", "baz"))
    +      .name("fixed").`type`(SchemaBuilder.fixed("simple_fixed").size(16))
    +      .withDefault("string_length_16")
    +      .endRecord()
    +    val encoder = AvroEncoder.of[GenericData.Record](schema)
    +    val expressionEncoder = 
encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]]
    +    val record = new GenericRecordBuilder(schema).build
    +    val row = expressionEncoder.toRow(record)
    +    val recordFromRow = expressionEncoder.resolveAndBind().fromRow(row)
    +    assert(record == recordFromRow)
    +  }
    +
    +  test("encoder resolves union types to rows") {
    +    val schema = SchemaBuilder.record("record").fields()
    +      .name("int_null_union").`type`(
    +      SchemaBuilder.unionOf.`type`("null").and.`type`("int").endUnion)
    +      .withDefault(null)
    +      .name("string_null_union").`type`(
    +      SchemaBuilder.unionOf.`type`("null").and.`type`("string").endUnion)
    +      .withDefault(null)
    +      .name("int_long_union").`type`(
    +      SchemaBuilder.unionOf.`type`("int").and.`type`("long").endUnion)
    +      .withDefault(0)
    +      .name("float_double_union").`type`(
    +      SchemaBuilder.unionOf.`type`("float").and.`type`("double").endUnion)
    +      .withDefault(0.0)
    +      .endRecord
    +    val encoder = AvroEncoder.of[GenericData.Record](schema)
    +    val expressionEncoder = 
encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]]
    +    val record = new GenericRecordBuilder(schema).build
    +    val row = expressionEncoder.toRow(record)
    +    val recordFromRow = expressionEncoder.resolveAndBind().fromRow(row)
    +    assert(record.get(0) == recordFromRow.get(0))
    +    assert(record.get(1) == recordFromRow.get(1))
    +    assert(record.get(2) == recordFromRow.get(2))
    +    assert(record.get(3) == recordFromRow.get(3))
    +    record.put(0, 0)
    +    record.put(1, "value")
    +    val updatedRow = expressionEncoder.toRow(record)
    +    val updatedRecordFromRow = 
expressionEncoder.resolveAndBind().fromRow(updatedRow)
    +    assert(record.get(0) == updatedRecordFromRow.get(0))
    +    assert(record.get(1) == updatedRecordFromRow.get(1))
    +  }
    +
    +  test("encoder resolves complex unions to rows") {
    +    val nested =
    +      SchemaBuilder.record("simple_record").fields()
    +        .name("nested1").`type`("int").withDefault(0)
    +        .name("nested2").`type`("string").withDefault("foo").endRecord()
    +    val schema = SchemaBuilder.record("record").fields()
    +      .name("int_float_string_record").`type`(
    +      SchemaBuilder.unionOf()
    +        .`type`("null").and()
    +        .`type`("int").and()
    +        .`type`("float").and()
    +        .`type`("string").and()
    +        .`type`(nested).endUnion()
    +    ).withDefault(null).endRecord()
    +
    +    val encoder = AvroEncoder.of[GenericData.Record](schema)
    +    val expressionEncoder = 
encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]]
    +    val record = new GenericRecordBuilder(schema).build
    +    var row = expressionEncoder.toRow(record)
    +    var recordFromRow = expressionEncoder.resolveAndBind().fromRow(row)
    +
    +    assert(row.getStruct(0, 4).get(0, IntegerType) == null)
    +    assert(row.getStruct(0, 4).get(1, FloatType) == null)
    +    assert(row.getStruct(0, 4).get(2, StringType) == null)
    +    assert(row.getStruct(0, 4).getStruct(3, 2) == null)
    +    assert(record == recordFromRow)
    +
    +    record.put(0, 1)
    +    row = expressionEncoder.toRow(record)
    +    recordFromRow = expressionEncoder.resolveAndBind().fromRow(row)
    +
    +    assert(row.getStruct(0, 4).get(1, FloatType) == null)
    +    assert(row.getStruct(0, 4).get(2, StringType) == null)
    +    assert(row.getStruct(0, 4).getStruct(3, 2) == null)
    +    assert(record == recordFromRow)
    +
    +    record.put(0, 1F)
    +    row = expressionEncoder.toRow(record)
    +    recordFromRow = expressionEncoder.resolveAndBind().fromRow(row)
    +
    +    assert(row.getStruct(0, 4).get(0, IntegerType) == null)
    +    assert(row.getStruct(0, 4).get(2, StringType) == null)
    +    assert(row.getStruct(0, 4).getStruct(3, 2) == null)
    +    assert(record == recordFromRow)
    +
    +    record.put(0, "bar")
    +    row = expressionEncoder.toRow(record)
    +    recordFromRow = expressionEncoder.resolveAndBind().fromRow(row)
    +
    +    assert(row.getStruct(0, 4).get(0, IntegerType) == null)
    +    assert(row.getStruct(0, 4).get(1, FloatType) == null)
    +    assert(row.getStruct(0, 4).getStruct(3, 2) == null)
    +    assert(record == recordFromRow)
    +
    +    record.put(0, new GenericRecordBuilder(nested).build())
    +    row = expressionEncoder.toRow(record)
    +    recordFromRow = expressionEncoder.resolveAndBind().fromRow(row)
    +
    +    assert(row.getStruct(0, 4).get(0, IntegerType) == null)
    +    assert(row.getStruct(0, 4).get(1, FloatType) == null)
    +    assert(row.getStruct(0, 4).get(2, StringType) == null)
    +    assert(record == recordFromRow)
    +  }
    +
    +  test("create Dataset from GenericRecord") {
    +    // need a spark context with kryo as serializer
    +    val conf = new SparkConf()
    +      .set("spark.serializer", 
"org.apache.spark.serializer.KryoSerializer")
    +      .set("spark.driver.allowMultipleContexts", "true")
    +      .set("spark.master", "local[2]")
    +      .set("spark.app.name", "AvroSuite")
    +    val context = new SparkContext(conf)
    +
    +    val schema: Schema =
    +      SchemaBuilder
    +        .record("GenericRecordTest")
    +        .namespace("com.databricks.spark.avro")
    +        .fields()
    +        .requiredString("field1")
    +        .name("enumVal").`type`().enumeration("letters").symbols("a", "b", 
"c").enumDefault("a")
    +        
.name("fixedVal").`type`().fixed("MD5").size(16).fixedDefault(ByteBuffer.allocate(16))
    +        .endRecord()
    +
    +    implicit val enc = AvroEncoder.of[GenericData.Record](schema)
    +
    +    val genericRecords = (1 to 10) map { i =>
    +      new GenericRecordBuilder(schema)
    +        .set("field1", "field-" + i)
    +        .build()
    +    }
    +
    +    val rdd: RDD[GenericData.Record] = context.parallelize(genericRecords)
    +    val ds = rdd.toDS()
    +    assert(ds.count() == genericRecords.size)
    +    context.stop()
    +  }
     }
    --- End diff --
    
    The above tests all use `GenericRecord` Avro classes. It might be good 
to generate an Avro class with a schema similar to that of the GenericRecord 
described 
[above](https://github.com/apache/spark/pull/22878/files#diff-9364b0610f92b3cc35a4bc43a80751bfR1386),
 so that we can also test an instance extending `SpecificRecord` (which will 
probably be the most commonly used kind of Avro class for the encoder).
    
    There was one such 
[class](https://github.com/databricks/spark-avro/pull/217/files#diff-6088df231dbde4904100296c8a90fe93)
 in the Spark-Avro project, but I can understand why it may not have been 
copied over in this PR.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to