This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new e397a1585d3  [SPARK-42406] Terminate Protobuf recursive fields by dropping the field
e397a1585d3 is described below

commit e397a1585d3b089f155ccb4359d5525cd012d5da
Author: Raghu Angadi <raghu.ang...@databricks.com>
AuthorDate: Mon Feb 27 20:44:17 2023 -0800

[SPARK-42406] Terminate Protobuf recursive fields by dropping the field

### What changes were proposed in this pull request?

The Protobuf deserializer (the `from_protobuf()` function) optionally supports recursive fields up to a certain depth. Currently it uses `NullType` to terminate the recursion. But an `ArrayType` containing `NullType` is not really useful, and it does not work with Delta. This PR fixes this by removing the field to terminate the recursion rather than using `NullType`.

The following example illustrates the difference. Consider a recursive Protobuf like this:

```
message Node {
  int32 value = 1;
  repeated Node children = 2; // recursive array
}
message Tree {
  Node root = 1;
}
```

The Catalyst schema from `from_protobuf()` of `Tree`, with max recursive depth set to 2, would be:

- **Before**: _STRUCT<root: STRUCT<value: int, children: array<STRUCT<value: int, **children: array< void >**>>>>_
- **After**: _STRUCT<root: STRUCT<value: int, children: array<STRUCT<value: int>>>>_

Notice that at the second level, the `children` array is dropped rather than being defined as `array<void>`.

### Why are the changes needed?

- This improves how the Protobuf connector handles recursive fields. It avoids `void` fields, which are problematic in many scenarios and do not add any information.

### Does this PR introduce _any_ user-facing change?

- This changes the schema in a subtle manner when recursive support is enabled. Since it only removes an optional field, it is backward compatible.

### How was this patch tested?

- Added multiple unit tests and updated an existing one. Most of the changes in this PR are in the tests.

Closes #40141 from rangadi/recursive-fields.

Authored-by: Raghu Angadi <raghu.ang...@databricks.com>
Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../spark/sql/protobuf/ProtobufDeserializer.scala |   3 -
 .../sql/protobuf/utils/SchemaConverters.scala     |  84 ++-
 .../test/resources/protobuf/functions_suite.desc  | Bin 8836 -> 9648 bytes
 .../test/resources/protobuf/functions_suite.proto |  29 +-
 .../sql/protobuf/ProtobufFunctionsSuite.scala     | 619 ++++++++-------------
 5 files changed, 301 insertions(+), 434 deletions(-)

diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala index 37278fab8a3..7723687a4d9 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala @@ -157,9 +157,6 @@ private[sql] class ProtobufDeserializer( (protoType.getJavaType, catalystType) match { case (null, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal) - // It is possible that this will result in data being dropped, This is intentional, - // to catch recursive fields and drop them as necessary. - case (MESSAGE, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal) // TODO: we can avoid boxing if future version of Protobuf provide primitive accessors.
case (BOOLEAN, BooleanType) => diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala index 9666e34bab4..e277f2999e4 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala @@ -21,11 +21,12 @@ import scala.collection.JavaConverters._ import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ @DeveloperApi -object SchemaConverters { +object SchemaConverters extends Logging { /** * Internal wrapper for SQL data type and nullability. @@ -59,6 +60,8 @@ object SchemaConverters { // existingRecordNames: Map[String, Int] used to track the depth of recursive fields and to // ensure that the conversion of the protobuf message to a Spark SQL StructType object does not // exceed the maximum recursive depth specified by the recursiveFieldMaxDepth option. + // A return of None implies the field has reached the maximum allowed recursive depth and + // should be dropped. def structFieldFor( fd: FieldDescriptor, existingRecordNames: Map[String, Int], @@ -84,10 +87,10 @@ object SchemaConverters { fd.getMessageType.getFields.size() == 2 && fd.getMessageType.getFields.get(0).getName.equals("seconds") && fd.getMessageType.getFields.get(1).getName.equals("nanos")) => - Some(TimestampType) + Some(TimestampType) case MESSAGE if fd.isRepeated && fd.getMessageType.getOptions.hasMapEntry => - var keyType: DataType = NullType - var valueType: DataType = NullType + var keyType: Option[DataType] = None + var valueType: Option[DataType] = None fd.getMessageType.getFields.forEach { field => field.getName match { case "key" => @@ -95,32 +98,42 @@ object SchemaConverters { structFieldFor( field, existingRecordNames, - protobufOptions).get.dataType + protobufOptions).map(_.dataType) case "value" => valueType = structFieldFor( field, existingRecordNames, - protobufOptions).get.dataType + protobufOptions).map(_.dataType) } } - return Option( - StructField( - fd.getName, - MapType(keyType, valueType, valueContainsNull = false).defaultConcreteType, - nullable = false)) + (keyType, valueType) match { + case (None, _) => + // This is probably never expected. Protobuf does not allow complex types for keys. + log.info(s"Dropping map field ${fd.getFullName}. Key reached max recursive depth.") + None + case (_, None) => + log.info(s"Dropping map field ${fd.getFullName}. Value reached max recursive depth.") + None + case (Some(kt), Some(vt)) => Some(MapType(kt, vt, valueContainsNull = false)) + } case MESSAGE => - // If the `recursive.fields.max.depth` value is not specified, it will default to -1; - // recursive fields are not permitted. Setting it to 0 drops all recursive fields, + // If the `recursive.fields.max.depth` value is not specified, it will default to -1, + // and recursive fields are not permitted. Setting it to 0 drops all recursive fields, // 1 allows it to be recursed once, and 2 allows it to be recursed twice and so on. // A value greater than 10 is not allowed, and if a protobuf record has more depth for // recursive fields than the allowed value, it will be truncated and some fields may be // discarded. 
- // SQL Schema for the protobuf message `message Person { string name = 1; Person bff = 2}` + // SQL Schema for protobuf `message Person { string name = 1; Person bff = 2;}` // will vary based on the value of "recursive.fields.max.depth". - // 1: struct<name: string, bff: null> - // 2: struct<name string, bff: <name: string, bff: null>> - // 3: struct<name string, bff: <name: string, bff: struct<name: string, bff: null>>> ... + // 1: struct<name: string> + // 2: struct<name: string, bff: struct<name: string>> + // 3: struct<name: string, bff: struct<name: string, bff: struct<name: string>>> + // and so on. + // TODO(rangadi): A better way to terminate would be to replace the remaining recursive struct + // with the byte array of the corresponding protobuf. This way no information is lost. + // i.e. with max depth 2, the above looks like this: + // struct<name: string, bff: struct<name: string, _serialized_bff: bytes>> val recordName = fd.getMessageType.getFullName val recursiveDepth = existingRecordNames.getOrElse(recordName, 0) val recursiveFieldMaxDepth = protobufOptions.recursiveFieldMaxDepth @@ -129,23 +142,36 @@ object SchemaConverters { throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString()) } else if (existingRecordNames.contains(recordName) && recursiveDepth >= recursiveFieldMaxDepth) { - Some(NullType) + // Recursive depth limit is reached. This field is dropped. + // If it is inside a container like map or array, the containing field is dropped. + log.info( + s"The field ${fd.getFullName} of type $recordName is dropped " + + s"at recursive depth $recursiveDepth" + ) + None } else { val newRecordNames = existingRecordNames + (recordName -> (recursiveDepth + 1)) - Option( - fd.getMessageType.getFields.asScala - .flatMap(structFieldFor(_, newRecordNames, protobufOptions)) - .toSeq) - .filter(_.nonEmpty) - .map(StructType.apply) + val fields = fd.getMessageType.getFields.asScala.flatMap( + structFieldFor(_, newRecordNames, protobufOptions) + ).toSeq + fields match { + case Nil => + log.info( + s"Dropping ${fd.getFullName} as it does not have any fields left, " + + "likely due to the recursive depth limit."
+ ) + None + case fds => Some(StructType(fds)) + } } case other => throw QueryCompilationErrors.protobufTypeUnsupportedYetError(other.toString) } - dataType.map(dt => - StructField( - fd.getName, - if (fd.isRepeated) ArrayType(dt, containsNull = false) else dt, - nullable = !fd.isRequired && !fd.isRepeated)) + dataType.map { + case dt: MapType => StructField(fd.getName, dt) + case dt if fd.isRepeated => + StructField(fd.getName, ArrayType(dt, containsNull = false)) + case dt => StructField(fd.getName, dt, nullable = !fd.isRequired) + } } } diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc index d16f8935080..467b9cac969 100644 Binary files a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc and b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc differ diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto index a0698ee3979..d83ba6a4f6e 100644 --- a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto +++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto @@ -233,6 +233,19 @@ message EventPersonWrapper { EventPerson person = 1; } +message PersonWithRecursiveArray { + // A protobuf with a recursive repeated field + string name = 1; + repeated PersonWithRecursiveArray friends = 2; +} + +message PersonWithRecursiveMap { + // A protobuf with a recursive field in the map value + string name = 1; + map<string, PersonWithRecursiveMap> groups = 3; +} + + message OneOfEventWithRecursion { string key = 1; oneof payload { @@ -243,14 +256,26 @@ message OneOfEventWithRecursion { } message EventRecursiveA { - OneOfEventWithRecursion recursiveA = 1; + OneOfEventWithRecursion recursiveOneOffInA = 1; string key = 2; } message EventRecursiveB { string key = 1; string value = 2; - OneOfEventWithRecursion recursiveA = 3; + OneOfEventWithRecursion recursiveOneOffInB = 3; +} + +message EmptyRecursiveProto { + // This is a recursive proto with no fields. Used to test an edge case. Catalyst schema for this + // should be "nothing" (i.e. completely dropped) irrespective of the recursive limit. + EmptyRecursiveProto recursive_field = 1; + repeated EmptyRecursiveProto recursive_array = 2; +} + +message EmptyRecursiveProtoWrapper { + string name = 1; + EmptyRecursiveProto empty_recursive = 2; // This field will be dropped.
} message Status { diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala index 60e13644fc6..92c3c27bfae 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql.protobuf import java.sql.Timestamp import java.time.Duration -import scala.collection.JavaConverters._ + import scala.collection.JavaConverters._ import com.google.protobuf.{ByteString, DynamicMessage} import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row} import org.apache.spark.sql.functions.{lit, struct} -import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{EM, EM2, Employee, EventPerson, EventPersonWrapper, EventRecursiveA, EventRecursiveB, IC, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated} +import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos._ import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum import org.apache.spark.sql.protobuf.utils.ProtobufUtils import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, IntegerType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types._ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with ProtobufTestBase with Serializable { @@ -56,10 +56,15 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot // A wrapper to invoke the right variable of from_protobuf() depending on arguments. 
private def from_protobuf_wrapper( - col: Column, messageName: String, descFilePathOpt: Option[String]): Column = { + col: Column, + messageName: String, + descFilePathOpt: Option[String], + options: Map[String, String] = Map.empty): Column = { descFilePathOpt match { - case Some(descFilePath) => functions.from_protobuf(col, messageName, descFilePath) - case None => functions.from_protobuf(col, messageName) + case Some(descFilePath) => functions.from_protobuf( + col, messageName, descFilePath, options.asJava + ) + case None => functions.from_protobuf(col, messageName, options.asJava) } } @@ -72,7 +77,6 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } } - test("roundtrip in to_protobuf and from_protobuf - struct") { val df = spark .range(1, 10) @@ -352,7 +356,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } } - test("roundtrip in from_protobuf and to_protobuf - Multiple Message") { + test("round trip in from_protobuf and to_protobuf - Multiple Message") { val messageMultiDesc = ProtobufUtils.buildDescriptor(testFileDesc, "MultipleExample") val messageIncludeDesc = ProtobufUtils.buildDescriptor(testFileDesc, "IncludedExample") val messageOtherDesc = ProtobufUtils.buildDescriptor(testFileDesc, "OtherExample") @@ -385,10 +389,9 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) } - } - test("Recursive fields in Protobuf should result in an error (B -> A -> B)") { - checkWithFileAndClassName("recursiveB") { + // Simple recursion + checkWithFileAndClassName("recursiveB") { // B -> A -> B case (name, descFilePathOpt) => val e = intercept[AnalysisException] { emptyBinaryDF.select( @@ -702,92 +705,44 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot assert(expectedFields.contains(f.getName)) }) - val jsonSchema = - s""" - |{ - | "type" : "struct", - | "fields" : [ { - | "name" : "sample", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "key", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "col_1", - | "type" : "integer", - | "nullable" : true - | }, { - | "name" : "col_2", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "col_3", - | "type" : "long", - | "nullable" : true - | }, { - | "name" : "col_4", - | "type" : { - | "type" : "array", - | "elementType" : "string", - | "containsNull" : false - | }, - | "nullable" : false - | } ] - | }, - | "nullable" : true - | } ] - |} - |{ - | "type" : "struct", - | "fields" : [ { - | "name" : "sample", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "key", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "col_1", - | "type" : "integer", - | "nullable" : true - | }, { - | "name" : "col_2", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "col_3", - | "type" : "long", - | "nullable" : true - | }, { - | "name" : "col_4", - | "type" : { - | "type" : "array", - | "elementType" : "string", - | "containsNull" : false - | }, - | "nullable" : false - | } ] - | }, - | "nullable" : true - | } ] - |} - |""".stripMargin - val schema = DataType.fromJson(jsonSchema).asInstanceOf[StructType] - val data = Seq(Row(Row("key", 123, "col2value", 109202L, Seq("col4value")))) + val schema = DataType.fromJson( + """ + | { + | "type":"struct", + | "fields":[ + 
| {"name":"sample","nullable":true,"type":{ + | "type":"struct", + | "fields":[ + | {"name":"key","type":"string","nullable":true}, + | {"name":"col_1","type":"integer","nullable":true}, + | {"name":"col_2","type":"string","nullable":true}, + | {"name":"col_3","type":"long","nullable":true}, + | {"name":"col_4","nullable":true,"type":{ + | "type":"array","elementType":"string","containsNull":false}} + | ]} + | } + | ] + | } + |""".stripMargin).asInstanceOf[StructType] + assert(fromProtoDf.schema == schema) + + val data = Seq( + Row(Row("key", 123, "col2value", 109202L, Seq("col4value"))), + Row(Row("key2", null, null, null, null)) // Leave the rest null, including "col_4" array. + ) val dataDf = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) val dataDfToProto = dataDf.select( to_protobuf_wrapper($"sample", name, descFilePathOpt) as 'toProto) - val eventFromSparkSchema = OneOfEvent.parseFrom( - dataDfToProto.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0)) + val toProtoResults = dataDfToProto.select("toProto").collect() + val eventFromSparkSchema = OneOfEvent.parseFrom(toProtoResults(0).getAs[Array[Byte]](0)) assert(eventFromSparkSchema.getCol2.isEmpty) assert(eventFromSparkSchema.getCol3 == 109202L) eventFromSparkSchema.getDescriptorForType.getFields.asScala.map(f => { assert(expectedFields.contains(f.getName)) }) + val secondEventFromSpark = OneOfEvent.parseFrom(toProtoResults(1).getAs[Array[Byte]](0)) + assert(secondEventFromSpark.getKey == "key2") } } @@ -853,13 +808,13 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot .setKey("keyNested2").setValue("valueNested2").build() val nestedOne = EventRecursiveA.newBuilder() .setKey("keyNested1") - .setRecursiveA(nestedTwo).build() + .setRecursiveOneOffInA(nestedTwo).build() val oneOfRecursionEvent = OneOfEventWithRecursion.newBuilder() .setKey("keyNested0") .setValue("valueNested0") .setRecursiveA(nestedOne).build() val recursiveA = EventRecursiveA.newBuilder().setKey("recursiveAKey") - .setRecursiveA(oneOfRecursionEvent).build() + .setRecursiveOneOffInA(oneOfRecursionEvent).build() val recursiveB = EventRecursiveB.newBuilder() .setKey("recursiveBKey") .setValue("recursiveBvalue").build() @@ -872,7 +827,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot val df = Seq(oneOfEventWithRecursion.toByteArray).toDF("value") val options = new java.util.HashMap[String, String]() - options.put("recursive.fields.max.depth", "2") + options.put("recursive.fields.max.depth", "2") // Recursive fields appear twice. 
val fromProtoDf = df.select( functions.from_protobuf($"value", @@ -896,177 +851,60 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot val eventFromSpark = OneOfEventWithRecursion.parseFrom( toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0)) - var recursiveField = eventFromSpark.getRecursiveA.getRecursiveA + var recursiveField = eventFromSpark.getRecursiveA.getRecursiveOneOffInA assert(recursiveField.getKey.equals("keyNested0")) assert(recursiveField.getValue.equals("valueNested0")) assert(recursiveField.getRecursiveA.getKey.equals("keyNested1")) - assert(recursiveField.getRecursiveA.getRecursiveA.getKey.isEmpty()) + assert(recursiveField.getRecursiveA.getRecursiveOneOffInA.getKey.isEmpty()) val expectedFields = descriptor.getFields.asScala.map(f => f.getName) eventFromSpark.getDescriptorForType.getFields.asScala.map(f => { assert(expectedFields.contains(f.getName)) }) - val jsonSchema = - s""" - |{ - | "type" : "struct", - | "fields" : [ { - | "name" : "sample", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "key", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "recursiveA", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "recursiveA", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "key", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "recursiveA", - | "type" : "void", - | "nullable" : true - | }, { - | "name" : "recursiveB", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "key", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "value", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "recursiveA", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "key", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "recursiveA", - | "type" : "void", - | "nullable" : true - | }, { - | "name" : "recursiveB", - | "type" : "void", - | "nullable" : true - | }, { - | "name" : "value", - | "type" : "string", - | "nullable" : true - | } ] - | }, - | "nullable" : true - | } ] - | }, - | "nullable" : true - | }, { - | "name" : "value", - | "type" : "string", - | "nullable" : true - | } ] - | }, - | "nullable" : true - | }, { - | "name" : "key", - | "type" : "string", - | "nullable" : true - | } ] - | }, - | "nullable" : true - | }, { - | "name" : "recursiveB", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "key", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "value", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "recursiveA", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "key", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "recursiveA", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "recursiveA", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "key", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "recursiveA", - | "type" : "void", - | "nullable" : true - | }, { - | "name" : "recursiveB", - | "type" : "void", - | "nullable" : true - | }, { - | "name" : "value", - | "type" : "string", - | "nullable" : true - | } ] - | }, - | "nullable" : true - | }, { - | "name" : "key", - | "type" : "string", - | "nullable" : true - | } ] - | }, - | "nullable" : true - | }, { - | "name" : "recursiveB", - | "type" : "void", - | "nullable" : true - | }, { - | "name" : "value", - | "type" : "string", - | "nullable" : 
true - | } ] - | }, - | "nullable" : true - | } ] - | }, - | "nullable" : true - | }, { - | "name" : "value", - | "type" : "string", - | "nullable" : true - | } ] - | }, - | "nullable" : true - | } ] - |} - |""".stripMargin - - val schema = DataType.fromJson(jsonSchema).asInstanceOf[StructType] + val schemaDDL = + """ + | -- OneOfEvenWithRecursion with max depth 2. + | sample STRUCT< -- 1st level for OneOffWithRecursion + | key string, + | recursiveA STRUCT< -- 1st level for RecursiveA + | recursiveOneOffInA STRUCT< -- 2st level for OneOffWithRecursion + | key string, + | recursiveA STRUCT< -- 2st level for RecursiveA + | key string + | -- Removed recursiveOneOffInA: 3rd level for OneOffWithRecursion + | >, + | recursiveB STRUCT< + | key string, + | value string + | -- Removed recursiveOneOffInB: 3rd level for OneOffWithRecursion + | >, + | value string + | >, + | key string + | >, + | recursiveB STRUCT< -- 1st level for RecursiveB + | key string, + | value string, + | recursiveOneOffInB STRUCT< -- 2st level for OneOffWithRecursion + | key string, + | recursiveA STRUCT< -- 1st level for RecursiveA + | key string + | -- Removed recursiveOneOffInA: 3rd level for OneOffWithRecursion + | >, + | recursiveB STRUCT< + | key string, + | value string + | -- Removed recursiveOneOffInB: 3rd level for OneOffWithRecursion + | >, + | value string + | > + | >, + | value string + | > + |""".stripMargin + val schema = structFromDDL(schemaDDL) + assert(fromProtoDf.schema == schema) val data = Seq( Row( Row("key1", @@ -1083,7 +921,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot val eventFromSparkSchema = OneOfEventWithRecursion.parseFrom( dataDfToProto.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0)) - recursiveField = eventFromSparkSchema.getRecursiveA.getRecursiveA + recursiveField = eventFromSparkSchema.getRecursiveA.getRecursiveOneOffInA assert(recursiveField.getKey.equals("keyNested0")) assert(recursiveField.getValue.equals("valueNested0")) assert(recursiveField.getRecursiveA.getKey.isEmpty()) @@ -1101,166 +939,61 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot val optionsZero = new java.util.HashMap[String, String]() optionsZero.put("recursive.fields.max.depth", "1") - val schemaZero = DataType.fromJson( - s"""{ - | "type" : "struct", - | "fields" : [ { - | "name" : "sample", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "name", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "bff", - | "type" : "void", - | "nullable" : true - | } ] - | }, - | "nullable" : true - | } ] - |}""".stripMargin).asInstanceOf[StructType] - val expectedDfZero = spark.createDataFrame( - spark.sparkContext.parallelize(Seq(Row(Row("person0", null)))), schemaZero) - - testFromProtobufWithOptions(df, expectedDfZero, optionsZero, "EventPerson") - - val optionsOne = new java.util.HashMap[String, String]() - optionsOne.put("recursive.fields.max.depth", "2") - val schemaOne = DataType.fromJson( - s"""{ - | "type" : "struct", - | "fields" : [ { - | "name" : "sample", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "name", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "bff", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "name", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "bff", - | "type" : "void", - | "nullable" : true - | } ] - | }, - | "nullable" : true - | } ] - | }, - | "nullable" : true - | } ] - 
|}""".stripMargin).asInstanceOf[StructType] + val schemaOne = structFromDDL( + "sample STRUCT<name: STRING>" // 'bff' field is dropped to due to limit of 1. + ) val expectedDfOne = spark.createDataFrame( - spark.sparkContext.parallelize(Seq(Row(Row("person0", Row("person1", null))))), schemaOne) - testFromProtobufWithOptions(df, expectedDfOne, optionsOne, "EventPerson") + spark.sparkContext.parallelize(Seq(Row(Row("person0", null)))), schemaOne) + testFromProtobufWithOptions(df, expectedDfOne, optionsZero, "EventPerson") val optionsTwo = new java.util.HashMap[String, String]() - optionsTwo.put("recursive.fields.max.depth", "3") - val schemaTwo = DataType.fromJson( - s"""{ - | "type" : "struct", - | "fields" : [ { - | "name" : "sample", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "name", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "bff", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "name", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "bff", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "name", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "bff", - | "type" : "void", - | "nullable" : true - | } ] - | }, - | "nullable" : true - | } ] - | }, - | "nullable" : true - | } ] - | }, - | "nullable" : true - | } ] - |}""".stripMargin).asInstanceOf[StructType] - val expectedDfTwo = spark.createDataFrame(spark.sparkContext.parallelize( - Seq(Row(Row("person0", Row("person1", Row("person2", null)))))), schemaTwo) + optionsTwo.put("recursive.fields.max.depth", "2") + val schemaTwo = structFromDDL( + """ + | sample STRUCT< + | name: STRING, + | bff: STRUCT<name: STRING> -- Recursion is terminated here. + | > + |""".stripMargin) + val expectedDfTwo = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(Row("person0", Row("person1", null))))), schemaTwo) testFromProtobufWithOptions(df, expectedDfTwo, optionsTwo, "EventPerson") + val optionsThree = new java.util.HashMap[String, String]() + optionsThree.put("recursive.fields.max.depth", "3") + val schemaThree = structFromDDL( + """ + | sample STRUCT< + | name: STRING, + | bff: STRUCT< + | name: STRING, + | bff: STRUCT<name: STRING> + | > + | > + |""".stripMargin) + val expectedDfThree = spark.createDataFrame(spark.sparkContext.parallelize( + Seq(Row(Row("person0", Row("person1", Row("person2", null)))))), schemaThree) + testFromProtobufWithOptions(df, expectedDfThree, optionsThree, "EventPerson") + // Test recursive level 1 with EventPersonWrapper. In this case the top level struct // 'EventPersonWrapper' itself does not recurse unlike 'EventPerson'. // "bff" appears twice: Once allowed recursion and second time as terminated "null" type. 
- val wrapperSchemaOne = DataType.fromJson( + val wrapperSchemaOne = structFromDDL( """ - |{ - | "type" : "struct", - | "fields" : [ { - | "name" : "sample", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "person", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "name", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "bff", - | "type" : { - | "type" : "struct", - | "fields" : [ { - | "name" : "name", - | "type" : "string", - | "nullable" : true - | }, { - | "name" : "bff", - | "type" : "void", - | "nullable" : true - | } ] - | }, - | "nullable" : true - | } ] - | }, - | "nullable" : true - | } ] - | }, - | "nullable" : true - | } ] - |} + | sample STRUCT< + | person: STRUCT< -- 1st level + | name: STRING, + | bff: STRUCT<name: STRING> -- 2nd level. Inner 3rd level Person is dropped. + | > + | > |""".stripMargin).asInstanceOf[StructType] - val expectedWrapperDfOne = spark.createDataFrame( + val expectedWrapperDfTwo = spark.createDataFrame( spark.sparkContext.parallelize(Seq(Row(Row(Row("person0", Row("person1", null)))))), wrapperSchemaOne) testFromProtobufWithOptions( Seq(EventPersonWrapper.newBuilder().setPerson(eventPerson0).build().toByteArray).toDF(), - expectedWrapperDfOne, - optionsOne, + expectedWrapperDfTwo, + optionsTwo, "EventPersonWrapper" ) } @@ -1287,6 +1020,92 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot assert(ex.getCause.getMessage.matches(".*No such file.*"), ex.getCause.getMessage()) } + test("Recursive fields in arrays and maps") { + // Verifies schema for recursive proto in an array field & map field. + val options = Map("recursive.fields.max.depth" -> "3") + + checkWithFileAndClassName("PersonWithRecursiveArray") { + case (name, descFilePathOpt) => + val expectedSchema = StructType( + // DDL: "proto STRUCT<name: string, friends: array< + // struct<name: string, friends: array<struct<name: string>>>>>" + // Can not use DataType.fromDDL(), it does not support "containsNull: false" for arrays. 
+ StructField("proto", + StructType( // 1st level + StructField("name", StringType) :: StructField("friends", // 2nd level + ArrayType( + StructType(StructField("name", StringType) :: StructField("friends", // 3rd level + ArrayType( + StructType(StructField("name", StringType) :: Nil), // 4th, array dropped + containsNull = false) + ):: Nil), + containsNull = false) + ) :: Nil + ) + ) :: Nil + ) + + val df = emptyBinaryDF.select( + from_protobuf_wrapper($"binary", name, descFilePathOpt, options).as("proto") + ) + assert(df.schema == expectedSchema) + } + + checkWithFileAndClassName("PersonWithRecursiveMap") { + case (name, descFilePathOpt) => + val expectedSchema = StructType( + // DDL: "proto STRUCT<name: string, groups: map<string, + // struct<name: string, groups: map<string, struct<name: string>>>>>" + StructField("proto", + StructType( // 1st level + StructField("name", StringType) :: StructField("groups", // 2nd level + MapType( + StringType, + StructType(StructField("name", StringType) :: StructField("groups", // 3rd level + MapType( + StringType, + StructType(StructField("name", StringType) :: Nil), // 4th, map dropped + valueContainsNull = false) + ):: Nil), + valueContainsNull = false) + ) :: Nil + ) + ) :: Nil + ) + + val df = emptyBinaryDF.select( + from_protobuf_wrapper($"binary", name, descFilePathOpt, options).as("proto") + ) + assert(df.schema == expectedSchema) + } + } + + test("Corner case: empty recursive proto fields should be dropped") { + // This verifies that an empty proto like 'message A { A a = 1; }' is completely dropped + // irrespective of the max depth setting. + + val options = Map("recursive.fields.max.depth" -> "4") + + // EmptyRecursiveProto at the top level. It will be an empty struct. + checkWithFileAndClassName("EmptyRecursiveProto") { + case (name, descFilePathOpt) => + val df = emptyBinaryDF.select( + from_protobuf_wrapper($"binary", name, descFilePathOpt, options).as("empty_proto") + ) + assert(df.schema == structFromDDL("empty_proto struct<>")) + } + + // EmptyRecursiveProto at an inner level. + checkWithFileAndClassName("EmptyRecursiveProtoWrapper") { + case (name, descFilePathOpt) => + val df = emptyBinaryDF.select( + from_protobuf_wrapper($"binary", name, descFilePathOpt, options).as("wrapper") + ) + // 'empty_recursive' field is dropped from the schema. Only "name" is present. + assert(df.schema == structFromDDL("wrapper struct<name: string>")) + } + } + def testFromProtobufWithOptions( df: DataFrame, expectedDf: DataFrame,
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
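For reference, a minimal usage sketch of the behavior described in the commit message above. The input DataFrame, column name, and descriptor file path here are hypothetical placeholders; `Tree`/`Node` are the example messages from the commit message, and the schemas in the comments are the before/after shapes quoted there.

```scala
import scala.collection.JavaConverters._

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.from_protobuf

// Hypothetical setup: `df` is a DataFrame with a binary column "value" holding serialized
// `Tree` messages, and "/tmp/functions_suite.desc" is a compiled descriptor file that
// contains the `Tree` and `Node` messages from the example in the commit message.
val options = Map("recursive.fields.max.depth" -> "2")

val parsed = df.select(
  from_protobuf(col("value"), "Tree", "/tmp/functions_suite.desc", options.asJava).as("tree"))

// With this change, the schema of "tree" is
//   STRUCT<root: STRUCT<value: int, children: array<STRUCT<value: int>>>>
// i.e. the recursion is terminated by dropping the innermost `children` field instead of
// emitting `children: array<void>` as before.
```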