This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new d33a59c940f [SPARK-41396][SQL][PROTOBUF] OneOf field support and recursion checks
d33a59c940f is described below

commit d33a59c940f0e8f0b93d91cc9e700c2cb533d54e
Author: SandishKumarHN <sanysand...@gmail.com>
AuthorDate: Wed Dec 21 09:37:15 2022 +0800

[SPARK-41396][SQL][PROTOBUF] OneOf field support and recursion checks

Oneof fields allow a message to contain one and only one of a defined set of field types, while recursive fields provide a way to define messages that can refer to themselves, allowing for the creation of complex and nested data structures. With this change, users will be able to use protobuf OneOf fields with spark-protobuf, making it a more complete and useful tool for processing protobuf data.

**Support for circularReferenceDepth:**

The `recursive.fields.max.depth` parameter can be specified in the from_protobuf options to control the maximum allowed recursion depth for a field. Setting `recursive.fields.max.depth` to 0 drops all recursive fields, setting it to 1 allows a recursive field to be recursed once, and setting it to 2 allows it to be recursed twice. Attempting to set `recursive.fields.max.depth` to a value greater than 10 is not allowed. If `recursive.fields.max.depth` is not specified, it will default to -1; [...]

The SQL schema for the protobuf message
```
message Person { string name = 1; Person bff = 2; }
```
will vary based on the value of `recursive.fields.max.depth`:
```
0: struct<name: string, bff: null>
1: struct<name: string, bff: struct<name: string, bff: null>>
2: struct<name: string, bff: struct<name: string, bff: struct<name: string, bff: null>>>
...
```

### What changes were proposed in this pull request?
- Add support for protobuf OneOf fields.
- Stop recursion at the first level when a recursive field is encountered, instead of throwing an error.

### Why are the changes needed?
To stop recursion at the first level and handle NullType in deserialization.

### Does this PR introduce _any_ user-facing change?
NA

### How was this patch tested?
Added unit tests for OneOf field support and recursion checks. Tested full support for nested OneOf fields and message types using real data from Kafka on a real cluster.

cc: rangadi mposdev21

Closes #38922 from SandishKumarHN/SPARK-41396.
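As a usage illustration (not part of this patch), the sketch below shows how `recursive.fields.max.depth` would be passed to `from_protobuf`. It assumes a DataFrame `df` with a binary column `value` holding serialized `Person` messages and a hypothetical descriptor file `/tmp/person.desc`; both names are placeholders rather than anything defined in the patch.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.from_protobuf

// `df` is assumed to carry serialized Person messages in a binary column "value";
// "/tmp/person.desc" is a placeholder for a descriptor set generated with protoc.
def parsePersons(df: DataFrame): DataFrame = {
  val options = new java.util.HashMap[String, String]()
  // Allow the recursive `bff` field to be expanded twice; anything deeper is
  // truncated to a null column instead of failing the query.
  options.put("recursive.fields.max.depth", "2")
  df.select(from_protobuf(col("value"), "Person", "/tmp/person.desc", options).as("person"))
}
```

With a depth of 2, the resulting `person` column takes the third schema listed above. For OneOf messages, every branch of the oneof surfaces as a nullable column in the struct, and, as the new tests point out, converting such a struct back with `to_protobuf` keeps only the last-set branch.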
Authored-by: SandishKumarHN <sanysand...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/protobuf/ProtobufDataToCatalyst.scala | 2 +- .../spark/sql/protobuf/ProtobufDeserializer.scala | 8 +- .../spark/sql/protobuf/utils/ProtobufOptions.scala | 8 + .../sql/protobuf/utils/SchemaConverters.scala | 69 ++- .../test/resources/protobuf/functions_suite.desc | Bin 6678 -> 8739 bytes .../test/resources/protobuf/functions_suite.proto | 85 ++- .../sql/protobuf/ProtobufFunctionsSuite.scala | 576 ++++++++++++++++++++- core/src/main/resources/error/error-classes.json | 2 +- 8 files changed, 721 insertions(+), 29 deletions(-) diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala index c0997b1bd06..da44f94d5ea 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala @@ -39,7 +39,7 @@ private[protobuf] case class ProtobufDataToCatalyst( override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) override lazy val dataType: DataType = { - val dt = SchemaConverters.toSqlType(messageDescriptor).dataType + val dt = SchemaConverters.toSqlType(messageDescriptor, protobufOptions).dataType parseMode match { // With PermissiveMode, the output Catalyst row might contain columns of null values for // corrupt records, even if some of the columns are not nullable in the user-provided schema. diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala index 46366ba268b..224e22c0f52 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala @@ -156,6 +156,9 @@ private[sql] class ProtobufDeserializer( (protoType.getJavaType, catalystType) match { case (null, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal) + // It is possible that this will result in data being dropped, This is intentional, + // to catch recursive fields and drop them as necessary. + case (MESSAGE, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal) // TODO: we can avoid boxing if future version of Protobuf provide primitive accessors. 
case (BOOLEAN, BooleanType) => @@ -171,7 +174,7 @@ private[sql] class ProtobufDeserializer( (updater, ordinal, value) => updater.setShort(ordinal, value.asInstanceOf[Short]) case ( - BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING, + MESSAGE | BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING, ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated => newArrayWriter(protoType, protoPath, catalystPath, dataType, containsNull) @@ -235,9 +238,6 @@ private[sql] class ProtobufDeserializer( writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage]) updater.set(ordinal, row) - case (MESSAGE, ArrayType(st: StructType, containsNull)) => - newArrayWriter(protoType, protoPath, catalystPath, st, containsNull) - case (ENUM, StringType) => (updater, ordinal, value) => updater.set(ordinal, UTF8String.fromString(value.toString)) diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala index 1cece0d7966..52f9f74bd43 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala @@ -38,6 +38,14 @@ private[sql] class ProtobufOptions( val parseMode: ParseMode = parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode) + + // Setting the `recursive.fields.max.depth` to 0 drops all recursive fields, + // 1 allows it to be recurse once, and 2 allows it to be recursed twice and so on. + // A value of `recursive.fields.max.depth` greater than 10 is not permitted. If it is not + // specified, the default value is -1; recursive fields are not permitted. If a protobuf + // record has more depth than the allowed value for recursive fields, it will be truncated + // and some fields may be discarded. 
+ val recursiveFieldMaxDepth: Int = parameters.getOrElse("recursive.fields.max.depth", "-1").toInt } private[sql] object ProtobufOptions { diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala index 4979fb9a504..8d321c13a56 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala @@ -40,19 +40,30 @@ object SchemaConverters { * * @since 3.4.0 */ - def toSqlType(descriptor: Descriptor): SchemaType = { - toSqlTypeHelper(descriptor) + def toSqlType( + descriptor: Descriptor, + protobufOptions: ProtobufOptions = ProtobufOptions(Map.empty)): SchemaType = { + toSqlTypeHelper(descriptor, protobufOptions) } - def toSqlTypeHelper(descriptor: Descriptor): SchemaType = ScalaReflectionLock.synchronized { + def toSqlTypeHelper( + descriptor: Descriptor, + protobufOptions: ProtobufOptions): SchemaType = ScalaReflectionLock.synchronized { SchemaType( - StructType(descriptor.getFields.asScala.flatMap(structFieldFor(_, Set.empty)).toArray), + StructType(descriptor.getFields.asScala.flatMap( + structFieldFor(_, + Map(descriptor.getFullName -> 1), + protobufOptions: ProtobufOptions)).toArray), nullable = true) } + // existingRecordNames: Map[String, Int] used to track the depth of recursive fields and to + // ensure that the conversion of the protobuf message to a Spark SQL StructType object does not + // exceed the maximum recursive depth specified by the recursiveFieldMaxDepth option. def structFieldFor( fd: FieldDescriptor, - existingRecordNames: Set[String]): Option[StructField] = { + existingRecordNames: Map[String, Int], + protobufOptions: ProtobufOptions): Option[StructField] = { import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ val dataType = fd.getJavaType match { case INT => Some(IntegerType) @@ -81,9 +92,17 @@ object SchemaConverters { fd.getMessageType.getFields.forEach { field => field.getName match { case "key" => - keyType = structFieldFor(field, existingRecordNames).get.dataType + keyType = + structFieldFor( + field, + existingRecordNames, + protobufOptions).get.dataType case "value" => - valueType = structFieldFor(field, existingRecordNames).get.dataType + valueType = + structFieldFor( + field, + existingRecordNames, + protobufOptions).get.dataType } } return Option( @@ -92,17 +111,35 @@ object SchemaConverters { MapType(keyType, valueType, valueContainsNull = false).defaultConcreteType, nullable = false)) case MESSAGE => - if (existingRecordNames.contains(fd.getFullName)) { + // If the `recursive.fields.max.depth` value is not specified, it will default to -1; + // recursive fields are not permitted. Setting it to 0 drops all recursive fields, + // 1 allows it to be recursed once, and 2 allows it to be recursed twice and so on. + // A value greater than 10 is not allowed, and if a protobuf record has more depth for + // recursive fields than the allowed value, it will be truncated and some fields may be + // discarded. + // SQL Schema for the protobuf message `message Person { string name = 1; Person bff = 2}` + // will vary based on the value of "recursive.fields.max.depth". + // 0: struct<name: string, bff: null> + // 1: struct<name string, bff: <name: string, bff: null>> + // 2: struct<name string, bff: <name: string, bff: struct<name: string, bff: null>>> ... 
+ val recordName = fd.getMessageType.getFullName + val recursiveDepth = existingRecordNames.getOrElse(recordName, 0) + val recursiveFieldMaxDepth = protobufOptions.recursiveFieldMaxDepth + if (existingRecordNames.contains(recordName) && (recursiveFieldMaxDepth < 0 || + recursiveFieldMaxDepth > 10)) { throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString()) + } else if (existingRecordNames.contains(recordName) && + recursiveDepth > recursiveFieldMaxDepth) { + Some(NullType) + } else { + val newRecordNames = existingRecordNames + (recordName -> (recursiveDepth + 1)) + Option( + fd.getMessageType.getFields.asScala + .flatMap(structFieldFor(_, newRecordNames, protobufOptions)) + .toSeq) + .filter(_.nonEmpty) + .map(StructType.apply) } - val newRecordNames = existingRecordNames + fd.getFullName - - Option( - fd.getMessageType.getFields.asScala - .flatMap(structFieldFor(_, newRecordNames)) - .toSeq) - .filter(_.nonEmpty) - .map(StructType.apply) case other => throw QueryCompilationErrors.protobufTypeUnsupportedYetError(other.toString) } diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc index d54ee4337a5..135d489f520 100644 Binary files a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc and b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc differ diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto index 2fef8495c5e..449f1b68bb8 100644 --- a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto +++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto @@ -170,4 +170,87 @@ message timeStampMsg { message durationMsg { string key = 1; Duration duration = 2; -} \ No newline at end of file +} + +message OneOfEvent { + string key = 1; + oneof payload { + int32 col_1 = 2; + string col_2 = 3; + int64 col_3 = 4; + } + repeated string col_4 = 5; +} + +message EventWithRecursion { + int32 key = 1; + messageA a = 2; +} +message messageA { + EventWithRecursion a = 1; + messageB b = 2; +} +message messageB { + EventWithRecursion aa = 1; + messageC c = 2; +} +message messageC { + EventWithRecursion aaa = 1; + int32 key= 2; +} + +message Employee { + string firstName = 1; + string lastName = 2; + oneof role { + IC ic = 3; + EM em = 4; + EM2 em2 = 5; + } +} + +message IC { + repeated string skills = 1; + Employee icManager = 2; +} + +message EM { + int64 teamsize = 1; + Employee emManager = 2; +} + +message EM2 { + int64 teamsize = 1; + Employee em2Manager = 2; +} + +message EventPerson { + string name = 1; + EventPerson bff = 2; +} + +message OneOfEventWithRecursion { + string key = 1; + oneof payload { + EventRecursiveA recursiveA = 3; + EventRecursiveB recursiveB = 6; + } + string value = 7; +} + +message EventRecursiveA { + OneOfEventWithRecursion recursiveA = 1; + string key = 2; +} + +message EventRecursiveB { + string key = 1; + string value = 2; + OneOfEventWithRecursion recursiveA = 3; +} + +message Status { + int32 id = 1; + Timestamp trade_time = 2; + Status status = 3; +} diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala index e493bc66ca7..79d7f96414b 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala +++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala @@ -23,14 +23,13 @@ import scala.collection.JavaConverters._ import com.google.protobuf.{ByteString, DynamicMessage} -import org.apache.spark.sql.{Column, QueryTest, Row} -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row} import org.apache.spark.sql.functions.{lit, struct} -import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated +import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{messageA, messageB, messageC, EM, EM2, Employee, EventPerson, EventRecursiveA, EventRecursiveB, EventWithRecursion, IC, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated} import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum import org.apache.spark.sql.protobuf.utils.ProtobufUtils import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{DayTimeIntervalType, IntegerType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, IntegerType, StringType, StructField, StructType, TimestampType} class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with ProtobufTestBase with Serializable { @@ -417,7 +416,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot .show() } assert(e.getMessage.contains( - "Found recursive reference in Protobuf schema, which can not be processed by Spark:" + "Found recursive reference in Protobuf schema, which can not be processed by Spark" )) } } @@ -453,7 +452,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot .show() } assert(e.getMessage.contains( - "Found recursive reference in Protobuf schema, which can not be processed by Spark:" + "Found recursive reference in Protobuf schema, which can not be processed by Spark" )) } } @@ -693,4 +692,569 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot errorClass = "CANNOT_CONSTRUCT_PROTOBUF_DESCRIPTOR", parameters = Map("descFilePath" -> testFileDescriptor)) } + + test("Verify OneOf field between from_protobuf -> to_protobuf and struct -> from_protobuf") { + val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "OneOfEvent") + val oneOfEvent = OneOfEvent.newBuilder() + .setKey("key") + .setCol1(123) + .setCol3(109202L) + .setCol2("col2value") + .addCol4("col4value").build() + + val df = Seq(oneOfEvent.toByteArray).toDF("value") + + checkWithFileAndClassName("OneOfEvent") { + case (name, descFilePathOpt) => + val fromProtoDf = df.select( + from_protobuf_wrapper($"value", name, descFilePathOpt) as 'sample) + val toDf = fromProtoDf.select( + to_protobuf_wrapper($"sample", name, descFilePathOpt) as 'toProto) + val toFromDf = toDf.select( + from_protobuf_wrapper($"toProto", name, descFilePathOpt) as 'fromToProto) + checkAnswer(fromProtoDf, toFromDf) + val actualFieldNames = fromProtoDf.select("sample.*").schema.fields.toSeq.map(f => f.name) + descriptor.getFields.asScala.map(f => { + assert(actualFieldNames.contains(f.getName)) + }) + + val eventFromSpark = OneOfEvent.parseFrom( + toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0)) + // OneOf field: the last set value(by order) will overwrite all previous ones. 
+ assert(eventFromSpark.getCol2.equals("col2value")) + assert(eventFromSpark.getCol3 == 0) + val expectedFields = descriptor.getFields.asScala.map(f => f.getName) + eventFromSpark.getDescriptorForType.getFields.asScala.map(f => { + assert(expectedFields.contains(f.getName)) + }) + + val jsonSchema = + s""" + |{ + | "type" : "struct", + | "fields" : [ { + | "name" : "sample", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "key", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "col_1", + | "type" : "integer", + | "nullable" : true + | }, { + | "name" : "col_2", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "col_3", + | "type" : "long", + | "nullable" : true + | }, { + | "name" : "col_4", + | "type" : { + | "type" : "array", + | "elementType" : "string", + | "containsNull" : false + | }, + | "nullable" : false + | } ] + | }, + | "nullable" : true + | } ] + |} + |{ + | "type" : "struct", + | "fields" : [ { + | "name" : "sample", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "key", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "col_1", + | "type" : "integer", + | "nullable" : true + | }, { + | "name" : "col_2", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "col_3", + | "type" : "long", + | "nullable" : true + | }, { + | "name" : "col_4", + | "type" : { + | "type" : "array", + | "elementType" : "string", + | "containsNull" : false + | }, + | "nullable" : false + | } ] + | }, + | "nullable" : true + | } ] + |} + |""".stripMargin + val schema = DataType.fromJson(jsonSchema).asInstanceOf[StructType] + val data = Seq(Row(Row("key", 123, "col2value", 109202L, Seq("col4value")))) + val dataDf = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + val dataDfToProto = dataDf.select( + to_protobuf_wrapper($"sample", name, descFilePathOpt) as 'toProto) + + val eventFromSparkSchema = OneOfEvent.parseFrom( + dataDfToProto.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0)) + assert(eventFromSparkSchema.getCol2.isEmpty) + assert(eventFromSparkSchema.getCol3 == 109202L) + eventFromSparkSchema.getDescriptorForType.getFields.asScala.map(f => { + assert(expectedFields.contains(f.getName)) + }) + } + } + + test("Fail for recursion field with complex schema without recursive.fields.max.depth") { + val aEventWithRecursion = EventWithRecursion.newBuilder().setKey(2).build() + val aaEventWithRecursion = EventWithRecursion.newBuilder().setKey(3).build() + val aaaEventWithRecursion = EventWithRecursion.newBuilder().setKey(4).build() + val c = messageC.newBuilder().setAaa(aaaEventWithRecursion).setKey(12092) + val b = messageB.newBuilder().setAa(aaEventWithRecursion).setC(c) + val a = messageA.newBuilder().setA(aEventWithRecursion).setB(b).build() + val eventWithRecursion = EventWithRecursion.newBuilder().setKey(1).setA(a).build() + + val df = Seq(eventWithRecursion.toByteArray).toDF("protoEvent") + + checkWithFileAndClassName("EventWithRecursion") { + case (name, descFilePathOpt) => + val e = intercept[AnalysisException] { + df.select( + from_protobuf_wrapper($"protoEvent", name, descFilePathOpt).as("messageFromProto")) + .show() + } + assert(e.getMessage.contains( + "Found recursive reference in Protobuf schema, which can not be processed by Spark" + )) + } + } + + test("Verify recursion field with complex schema with recursive.fields.max.depth") { + val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "Employee") + + val manager = 
Employee.newBuilder().setFirstName("firstName").setLastName("lastName").build() + val em2 = EM2.newBuilder().setTeamsize(100).setEm2Manager(manager).build() + val em = EM.newBuilder().setTeamsize(100).setEmManager(manager).build() + val ic = IC.newBuilder().addSkills("java").setIcManager(manager).build() + val employee = Employee.newBuilder().setFirstName("firstName") + .setLastName("lastName").setEm2(em2).setEm(em).setIc(ic).build() + + val df = Seq(employee.toByteArray).toDF("protoEvent") + val options = new java.util.HashMap[String, String]() + options.put("recursive.fields.max.depth", "1") + + val fromProtoDf = df.select( + functions.from_protobuf($"protoEvent", "Employee", testFileDesc, options) as 'sample) + + val toDf = fromProtoDf.select( + functions.to_protobuf($"sample", "Employee", testFileDesc) as 'toProto) + val toFromDf = toDf.select( + functions.from_protobuf($"toProto", + "Employee", + testFileDesc, + options) as 'fromToProto) + + checkAnswer(fromProtoDf, toFromDf) + + val actualFieldNames = fromProtoDf.select("sample.*").schema.fields.toSeq.map(f => f.name) + descriptor.getFields.asScala.map(f => { + assert(actualFieldNames.contains(f.getName)) + }) + + val eventFromSpark = Employee.parseFrom( + toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0)) + + assert(eventFromSpark.getIc.getIcManager.getFirstName.equals("firstName")) + assert(eventFromSpark.getIc.getIcManager.getLastName.equals("lastName")) + assert(eventFromSpark.getEm2.getEm2Manager.getFirstName.isEmpty) + } + + test("Verify OneOf field with recursive fields between from_protobuf -> to_protobuf." + + "and struct -> from_protobuf") { + val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "OneOfEventWithRecursion") + + val nestedTwo = OneOfEventWithRecursion.newBuilder() + .setKey("keyNested2").setValue("valueNested2").build() + val nestedOne = EventRecursiveA.newBuilder() + .setKey("keyNested1") + .setRecursiveA(nestedTwo).build() + val oneOfRecursionEvent = OneOfEventWithRecursion.newBuilder() + .setKey("keyNested0") + .setValue("valueNested0") + .setRecursiveA(nestedOne).build() + val recursiveA = EventRecursiveA.newBuilder().setKey("recursiveAKey") + .setRecursiveA(oneOfRecursionEvent).build() + val recursiveB = EventRecursiveB.newBuilder() + .setKey("recursiveBKey") + .setValue("recursiveBvalue").build() + val oneOfEventWithRecursion = OneOfEventWithRecursion.newBuilder() + .setKey("key") + .setValue("value") + .setRecursiveB(recursiveB) + .setRecursiveA(recursiveA).build() + + val df = Seq(oneOfEventWithRecursion.toByteArray).toDF("value") + + val options = new java.util.HashMap[String, String]() + options.put("recursive.fields.max.depth", "1") + + val fromProtoDf = df.select( + functions.from_protobuf($"value", + "OneOfEventWithRecursion", + testFileDesc, options) as 'sample) + val toDf = fromProtoDf.select( + functions.to_protobuf($"sample", "OneOfEventWithRecursion", testFileDesc) as 'toProto) + val toFromDf = toDf.select( + functions.from_protobuf($"toProto", + "OneOfEventWithRecursion", + testFileDesc, + options) as 'fromToProto) + + checkAnswer(fromProtoDf, toFromDf) + + val actualFieldNames = fromProtoDf.select("sample.*").schema.fields.toSeq.map(f => f.name) + descriptor.getFields.asScala.map(f => { + assert(actualFieldNames.contains(f.getName)) + }) + + val eventFromSpark = OneOfEventWithRecursion.parseFrom( + toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0)) + + var recursiveField = eventFromSpark.getRecursiveA.getRecursiveA + 
assert(recursiveField.getKey.equals("keyNested0")) + assert(recursiveField.getValue.equals("valueNested0")) + assert(recursiveField.getRecursiveA.getKey.equals("keyNested1")) + assert(recursiveField.getRecursiveA.getRecursiveA.getKey.isEmpty()) + + val expectedFields = descriptor.getFields.asScala.map(f => f.getName) + eventFromSpark.getDescriptorForType.getFields.asScala.map(f => { + assert(expectedFields.contains(f.getName)) + }) + + val jsonSchema = + s""" + |{ + | "type" : "struct", + | "fields" : [ { + | "name" : "sample", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "key", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "recursiveA", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "recursiveA", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "key", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "recursiveA", + | "type" : "void", + | "nullable" : true + | }, { + | "name" : "recursiveB", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "key", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "value", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "recursiveA", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "key", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "recursiveA", + | "type" : "void", + | "nullable" : true + | }, { + | "name" : "recursiveB", + | "type" : "void", + | "nullable" : true + | }, { + | "name" : "value", + | "type" : "string", + | "nullable" : true + | } ] + | }, + | "nullable" : true + | } ] + | }, + | "nullable" : true + | }, { + | "name" : "value", + | "type" : "string", + | "nullable" : true + | } ] + | }, + | "nullable" : true + | }, { + | "name" : "key", + | "type" : "string", + | "nullable" : true + | } ] + | }, + | "nullable" : true + | }, { + | "name" : "recursiveB", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "key", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "value", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "recursiveA", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "key", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "recursiveA", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "recursiveA", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "key", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "recursiveA", + | "type" : "void", + | "nullable" : true + | }, { + | "name" : "recursiveB", + | "type" : "void", + | "nullable" : true + | }, { + | "name" : "value", + | "type" : "string", + | "nullable" : true + | } ] + | }, + | "nullable" : true + | }, { + | "name" : "key", + | "type" : "string", + | "nullable" : true + | } ] + | }, + | "nullable" : true + | }, { + | "name" : "recursiveB", + | "type" : "void", + | "nullable" : true + | }, { + | "name" : "value", + | "type" : "string", + | "nullable" : true + | } ] + | }, + | "nullable" : true + | } ] + | }, + | "nullable" : true + | }, { + | "name" : "value", + | "type" : "string", + | "nullable" : true + | } ] + | }, + | "nullable" : true + | } ] + |} + |""".stripMargin + + val schema = DataType.fromJson(jsonSchema).asInstanceOf[StructType] + val data = Seq( + Row( + Row("key1", + Row( + Row("keyNested0", null, null, "valueNested0"), + "recursiveAKey"), + null, + "value1") + ) + ) + val dataDf = 
spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + val dataDfToProto = dataDf.select( + functions.to_protobuf($"sample", "OneOfEventWithRecursion", testFileDesc) as 'toProto) + + val eventFromSparkSchema = OneOfEventWithRecursion.parseFrom( + dataDfToProto.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0)) + recursiveField = eventFromSparkSchema.getRecursiveA.getRecursiveA + assert(recursiveField.getKey.equals("keyNested0")) + assert(recursiveField.getValue.equals("valueNested0")) + assert(recursiveField.getRecursiveA.getKey.isEmpty()) + eventFromSparkSchema.getDescriptorForType.getFields.asScala.map(f => { + assert(expectedFields.contains(f.getName)) + }) + } + + test("Verify recursive.fields.max.depth Levels 0,1, and 2 with Simple Schema") { + val eventPerson3 = EventPerson.newBuilder().setName("person3").build() + val eventPerson2 = EventPerson.newBuilder().setName("person2").setBff(eventPerson3).build() + val eventPerson1 = EventPerson.newBuilder().setName("person1").setBff(eventPerson2).build() + val eventPerson0 = EventPerson.newBuilder().setName("person0").setBff(eventPerson1).build() + val df = Seq(eventPerson0.toByteArray).toDF("value") + + val optionsZero = new java.util.HashMap[String, String]() + optionsZero.put("recursive.fields.max.depth", "0") + val schemaZero = DataType.fromJson( + s"""{ + | "type" : "struct", + | "fields" : [ { + | "name" : "sample", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "name", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "bff", + | "type" : "void", + | "nullable" : true + | } ] + | }, + | "nullable" : true + | } ] + |}""".stripMargin).asInstanceOf[StructType] + val expectedDfZero = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(Row("person0", null)))), schemaZero) + + testFromProtobufWithOptions(df, expectedDfZero, optionsZero) + + val optionsOne = new java.util.HashMap[String, String]() + optionsOne.put("recursive.fields.max.depth", "1") + val schemaOne = DataType.fromJson( + s"""{ + | "type" : "struct", + | "fields" : [ { + | "name" : "sample", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "name", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "bff", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "name", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "bff", + | "type" : "void", + | "nullable" : true + | } ] + | }, + | "nullable" : true + | } ] + | }, + | "nullable" : true + | } ] + |}""".stripMargin).asInstanceOf[StructType] + val expectedDfOne = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(Row("person0", Row("person1", null))))), schemaOne) + testFromProtobufWithOptions(df, expectedDfOne, optionsOne) + + val optionsTwo = new java.util.HashMap[String, String]() + optionsTwo.put("recursive.fields.max.depth", "2") + val schemaTwo = DataType.fromJson( + s"""{ + | "type" : "struct", + | "fields" : [ { + | "name" : "sample", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "name", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "bff", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "name", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "bff", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "name", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "bff", + | "type" : "void", + | "nullable" : true + | } ] + | }, + | "nullable" : true + | } ] 
+ | }, + | "nullable" : true + | } ] + | }, + | "nullable" : true + | } ] + |}""".stripMargin).asInstanceOf[StructType] + val expectedDfTwo = spark.createDataFrame(spark.sparkContext.parallelize( + Seq(Row(Row("person0", Row("person1", Row("person2", null)))))), schemaTwo) + testFromProtobufWithOptions(df, expectedDfTwo, optionsTwo) + } + + def testFromProtobufWithOptions( + df: DataFrame, + expectedDf: DataFrame, + options: java.util.HashMap[String, String]): Unit = { + val fromProtoDf = df.select( + functions.from_protobuf($"value", "EventPerson", testFileDesc, options) as 'sample) + assert(expectedDf.schema === fromProtoDf.schema) + checkAnswer(fromProtoDf, expectedDf) + } } diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index b5e846a8a89..f176726d0ce 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -1048,7 +1048,7 @@ }, "RECURSIVE_PROTOBUF_SCHEMA" : { "message" : [ - "Found recursive reference in Protobuf schema, which can not be processed by Spark: <fieldDescriptor>" + "Found recursive reference in Protobuf schema, which can not be processed by Spark by default: <fieldDescriptor>. try setting the option `recursive.fields.max.depth` 0 to 10. Going beyond 10 levels of recursion is not allowed." ] }, "RENAME_SRC_PATH_NOT_FOUND" : { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org