This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d33a59c940f [SPARK-41396][SQL][PROTOBUF] OneOf field support and 
recursion checks
d33a59c940f is described below

commit d33a59c940f0e8f0b93d91cc9e700c2cb533d54e
Author: SandishKumarHN <sanysand...@gmail.com>
AuthorDate: Wed Dec 21 09:37:15 2022 +0800

    [SPARK-41396][SQL][PROTOBUF] OneOf field support and recursion checks
    
    Oneof fields allow a message to contain one and only one of a defined set of
field types, while recursive fields provide a way to define messages that can
refer to themselves, allowing for the creation of complex and nested data
structures. With this change, users will be able to use protobuf OneOf fields
with spark-protobuf, making it a more complete and useful tool for processing
protobuf data.
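
    For illustration, a minimal Scala sketch of reading OneOf data with
`from_protobuf` (the descriptor file path and input DataFrame are hypothetical);
each oneof branch becomes its own nullable column, and only the branch that was
set in the message is non-null:
    ```
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.protobuf.functions.from_protobuf

    // `df` holds serialized OneOfEvent bytes in a binary column named "value".
    val parsed = df.select(
      from_protobuf(col("value"), "OneOfEvent", "/path/to/functions_suite.desc") as "event")
    // Only the oneof branch that was set in the message is non-null after parsing.
    parsed.select("event.key", "event.col_1", "event.col_2", "event.col_3").show()
    ```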
    
    **Support for circularReferenceDepth:**
    The `recursive.fields.max.depth` parameter can be specified in the
from_protobuf options to control the maximum allowed recursion depth for a
field. Setting `recursive.fields.max.depth` to 0 drops all recursive fields,
setting it to 1 allows a field to be recursed once, and setting it to 2 allows
it to be recursed twice. Setting `recursive.fields.max.depth` to a value
greater than 10 is not allowed. If `recursive.fields.max.depth` is not
specified, it defaults to -1;  [...]
    The SQL schema for the protobuf message
     ```
    message Person {
         string name = 1;
         Person bff = 2;
    }
    ```
    will vary based on the value of `recursive.fields.max.depth`.
    ```
    0: struct<name: string, bff: null>
    1: struct<name: string, bff: struct<name: string, bff: null>>
    2: struct<name: string, bff: struct<name: string, bff: struct<name: string, bff: null>>> ...
    ```
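
    As a usage sketch (the descriptor file path and input DataFrame are
hypothetical), the option is passed through the `from_protobuf` options map:
    ```
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.protobuf.functions.from_protobuf

    val options = new java.util.HashMap[String, String]()
    options.put("recursive.fields.max.depth", "2")

    // Recursion in the `bff` field is truncated after two levels instead of
    // failing the query.
    val parsed = df.select(
      from_protobuf(col("value"), "Person", "/path/to/person.desc", options) as "person")
    ```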
    
    ### What changes were proposed in this pull request?
    
    - Add support for protobuf oneof field
    - Stop recursion at the first level when a recursive field is encountered
      (instead of throwing an error).
    
    ### Why are the changes needed?
    
    Stop recursion at the first level and handle NullType in deserialization.
    
    ### Does this PR introduce _any_ user-facing change?
    
    NA

    ### How was this patch tested?
    
    Added unit tests for OneOf field support and recursion checks.
    Tested full support for nested OneOf fields and message types using real
data from Kafka on a real cluster.
    
    cc: rangadi mposdev21
    
    Closes #38922 from SandishKumarHN/SPARK-41396.
    
    Authored-by: SandishKumarHN <sanysand...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/protobuf/ProtobufDataToCatalyst.scala      |   2 +-
 .../spark/sql/protobuf/ProtobufDeserializer.scala  |   8 +-
 .../spark/sql/protobuf/utils/ProtobufOptions.scala |   8 +
 .../sql/protobuf/utils/SchemaConverters.scala      |  69 ++-
 .../test/resources/protobuf/functions_suite.desc   | Bin 6678 -> 8739 bytes
 .../test/resources/protobuf/functions_suite.proto  |  85 ++-
 .../sql/protobuf/ProtobufFunctionsSuite.scala      | 576 ++++++++++++++++++++-
 core/src/main/resources/error/error-classes.json   |   2 +-
 8 files changed, 721 insertions(+), 29 deletions(-)

diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
index c0997b1bd06..da44f94d5ea 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
@@ -39,7 +39,7 @@ private[protobuf] case class ProtobufDataToCatalyst(
   override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
 
   override lazy val dataType: DataType = {
-    val dt = SchemaConverters.toSqlType(messageDescriptor).dataType
+    val dt = SchemaConverters.toSqlType(messageDescriptor, 
protobufOptions).dataType
     parseMode match {
       // With PermissiveMode, the output Catalyst row might contain columns of 
null values for
       // corrupt records, even if some of the columns are not nullable in the 
user-provided schema.
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
index 46366ba268b..224e22c0f52 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
@@ -156,6 +156,9 @@ private[sql] class ProtobufDeserializer(
     (protoType.getJavaType, catalystType) match {
 
       case (null, NullType) => (updater, ordinal, _) => 
updater.setNullAt(ordinal)
+      // It is possible that this will result in data being dropped, This is 
intentional,
+      // to catch recursive fields and drop them as necessary.
+      case (MESSAGE, NullType) => (updater, ordinal, _) => 
updater.setNullAt(ordinal)
 
       // TODO: we can avoid boxing if future version of Protobuf provide 
primitive accessors.
       case (BOOLEAN, BooleanType) =>
@@ -171,7 +174,7 @@ private[sql] class ProtobufDeserializer(
         (updater, ordinal, value) => updater.setShort(ordinal, 
value.asInstanceOf[Short])
 
       case  (
-        BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING,
+        MESSAGE | BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | 
BYTE_STRING,
         ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated =>
         newArrayWriter(protoType, protoPath, catalystPath, dataType, 
containsNull)
 
@@ -235,9 +238,6 @@ private[sql] class ProtobufDeserializer(
           writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage])
           updater.set(ordinal, row)
 
-      case (MESSAGE, ArrayType(st: StructType, containsNull)) =>
-        newArrayWriter(protoType, protoPath, catalystPath, st, containsNull)
-
       case (ENUM, StringType) =>
         (updater, ordinal, value) => updater.set(ordinal, 
UTF8String.fromString(value.toString))
 
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
index 1cece0d7966..52f9f74bd43 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
@@ -38,6 +38,14 @@ private[sql] class ProtobufOptions(
 
   val parseMode: ParseMode =
     parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode)
+
+  // Setting the `recursive.fields.max.depth` to 0 drops all recursive fields,
+  // 1 allows it to be recurse once, and 2 allows it to be recursed twice and 
so on.
+  // A value of `recursive.fields.max.depth` greater than 10 is not permitted. 
If it is not
+  // specified, the default value is -1; recursive fields are not permitted. 
If a protobuf
+  // record has more depth than the allowed value for recursive fields, it 
will be truncated
+  // and some fields may be discarded.
+  val recursiveFieldMaxDepth: Int = 
parameters.getOrElse("recursive.fields.max.depth", "-1").toInt
 }
 
 private[sql] object ProtobufOptions {
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
index 4979fb9a504..8d321c13a56 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -40,19 +40,30 @@ object SchemaConverters {
    *
    * @since 3.4.0
    */
-  def toSqlType(descriptor: Descriptor): SchemaType = {
-    toSqlTypeHelper(descriptor)
+  def toSqlType(
+      descriptor: Descriptor,
+      protobufOptions: ProtobufOptions = ProtobufOptions(Map.empty)): 
SchemaType = {
+    toSqlTypeHelper(descriptor, protobufOptions)
   }
 
-  def toSqlTypeHelper(descriptor: Descriptor): SchemaType = 
ScalaReflectionLock.synchronized {
+  def toSqlTypeHelper(
+      descriptor: Descriptor,
+      protobufOptions: ProtobufOptions): SchemaType = 
ScalaReflectionLock.synchronized {
     SchemaType(
-      StructType(descriptor.getFields.asScala.flatMap(structFieldFor(_, 
Set.empty)).toArray),
+      StructType(descriptor.getFields.asScala.flatMap(
+        structFieldFor(_,
+          Map(descriptor.getFullName -> 1),
+          protobufOptions: ProtobufOptions)).toArray),
       nullable = true)
   }
 
+  // existingRecordNames: Map[String, Int] used to track the depth of 
recursive fields and to
+  // ensure that the conversion of the protobuf message to a Spark SQL 
StructType object does not
+  // exceed the maximum recursive depth specified by the 
recursiveFieldMaxDepth option.
   def structFieldFor(
       fd: FieldDescriptor,
-      existingRecordNames: Set[String]): Option[StructField] = {
+      existingRecordNames: Map[String, Int],
+      protobufOptions: ProtobufOptions): Option[StructField] = {
     import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
     val dataType = fd.getJavaType match {
       case INT => Some(IntegerType)
@@ -81,9 +92,17 @@ object SchemaConverters {
         fd.getMessageType.getFields.forEach { field =>
           field.getName match {
             case "key" =>
-              keyType = structFieldFor(field, existingRecordNames).get.dataType
+              keyType =
+                structFieldFor(
+                  field,
+                  existingRecordNames,
+                  protobufOptions).get.dataType
             case "value" =>
-              valueType = structFieldFor(field, 
existingRecordNames).get.dataType
+              valueType =
+                structFieldFor(
+                  field,
+                  existingRecordNames,
+                  protobufOptions).get.dataType
           }
         }
         return Option(
@@ -92,17 +111,35 @@ object SchemaConverters {
             MapType(keyType, valueType, valueContainsNull = 
false).defaultConcreteType,
             nullable = false))
       case MESSAGE =>
-        if (existingRecordNames.contains(fd.getFullName)) {
+        // If the `recursive.fields.max.depth` value is not specified, it will 
default to -1;
+        // recursive fields are not permitted. Setting it to 0 drops all 
recursive fields,
+        // 1 allows it to be recursed once, and 2 allows it to be recursed 
twice and so on.
+        // A value greater than 10 is not allowed, and if a protobuf record 
has more depth for
+        // recursive fields than the allowed value, it will be truncated and 
some fields may be
+        // discarded.
+        // SQL Schema for the protobuf message `message Person { string name = 
1; Person bff = 2}`
+        // will vary based on the value of "recursive.fields.max.depth".
+        // 0: struct<name: string, bff: null>
+        // 1: struct<name string, bff: <name: string, bff: null>>
+        // 2: struct<name string, bff: <name: string, bff: struct<name: 
string, bff: null>>> ...
+        val recordName = fd.getMessageType.getFullName
+        val recursiveDepth = existingRecordNames.getOrElse(recordName, 0)
+        val recursiveFieldMaxDepth = protobufOptions.recursiveFieldMaxDepth
+        if (existingRecordNames.contains(recordName) && 
(recursiveFieldMaxDepth < 0 ||
+          recursiveFieldMaxDepth > 10)) {
           throw 
QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
+        } else if (existingRecordNames.contains(recordName) &&
+          recursiveDepth > recursiveFieldMaxDepth) {
+          Some(NullType)
+        } else {
+          val newRecordNames = existingRecordNames + (recordName -> 
(recursiveDepth + 1))
+          Option(
+            fd.getMessageType.getFields.asScala
+              .flatMap(structFieldFor(_, newRecordNames, protobufOptions))
+              .toSeq)
+            .filter(_.nonEmpty)
+            .map(StructType.apply)
         }
-        val newRecordNames = existingRecordNames + fd.getFullName
-
-        Option(
-          fd.getMessageType.getFields.asScala
-            .flatMap(structFieldFor(_, newRecordNames))
-            .toSeq)
-          .filter(_.nonEmpty)
-          .map(StructType.apply)
       case other =>
         throw 
QueryCompilationErrors.protobufTypeUnsupportedYetError(other.toString)
     }
diff --git 
a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc 
b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc
index d54ee4337a5..135d489f520 100644
Binary files 
a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc and 
b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc differ
diff --git 
a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto 
b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
index 2fef8495c5e..449f1b68bb8 100644
--- a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
+++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
@@ -170,4 +170,87 @@ message timeStampMsg {
 message durationMsg {
   string key = 1;
   Duration duration = 2;
-}
\ No newline at end of file
+}
+
+message OneOfEvent {
+  string key = 1;
+  oneof payload {
+    int32 col_1 = 2;
+    string col_2 = 3;
+    int64 col_3 = 4;
+  }
+  repeated string col_4 = 5;
+}
+
+message EventWithRecursion {
+  int32 key = 1;
+  messageA a = 2;
+}
+message messageA {
+  EventWithRecursion a = 1;
+  messageB b = 2;
+}
+message messageB {
+  EventWithRecursion aa = 1;
+  messageC c = 2;
+}
+message messageC {
+  EventWithRecursion aaa = 1;
+  int32 key= 2;
+}
+
+message Employee {
+  string firstName = 1;
+  string lastName = 2;
+  oneof role {
+    IC ic = 3;
+    EM em = 4;
+    EM2 em2 = 5;
+  }
+}
+
+message IC {
+  repeated string skills = 1;
+  Employee icManager = 2;
+}
+
+message EM {
+  int64 teamsize = 1;
+  Employee emManager = 2;
+}
+
+message EM2 {
+  int64 teamsize = 1;
+  Employee em2Manager = 2;
+}
+
+message EventPerson {
+  string name = 1;
+  EventPerson bff = 2;
+}
+
+message OneOfEventWithRecursion {
+  string key = 1;
+  oneof payload {
+    EventRecursiveA recursiveA = 3;
+    EventRecursiveB recursiveB = 6;
+  }
+  string value = 7;
+}
+
+message EventRecursiveA {
+  OneOfEventWithRecursion recursiveA = 1;
+  string key = 2;
+}
+
+message EventRecursiveB {
+  string key = 1;
+  string value = 2;
+  OneOfEventWithRecursion recursiveA = 3;
+}
+
+message Status {
+  int32 id = 1;
+  Timestamp trade_time = 2;
+  Status status = 3;
+}
diff --git 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index e493bc66ca7..79d7f96414b 100644
--- 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -23,14 +23,13 @@ import scala.collection.JavaConverters._
 
 import com.google.protobuf.{ByteString, DynamicMessage}
 
-import org.apache.spark.sql.{Column, QueryTest, Row}
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, 
Row}
 import org.apache.spark.sql.functions.{lit, struct}
-import 
org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated
+import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{messageA, 
messageB, messageC, EM, EM2, Employee, EventPerson, EventRecursiveA, 
EventRecursiveB, EventWithRecursion, IC, OneOfEvent, OneOfEventWithRecursion, 
SimpleMessageRepeated}
 import 
org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum
 import org.apache.spark.sql.protobuf.utils.ProtobufUtils
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{DayTimeIntervalType, IntegerType, 
StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, IntegerType, 
StringType, StructField, StructType, TimestampType}
 
 class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with 
ProtobufTestBase
   with Serializable {
@@ -417,7 +416,7 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
             .show()
         }
         assert(e.getMessage.contains(
-          "Found recursive reference in Protobuf schema, which can not be 
processed by Spark:"
+          "Found recursive reference in Protobuf schema, which can not be 
processed by Spark"
         ))
     }
   }
@@ -453,7 +452,7 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
             .show()
         }
         assert(e.getMessage.contains(
-          "Found recursive reference in Protobuf schema, which can not be 
processed by Spark:"
+          "Found recursive reference in Protobuf schema, which can not be 
processed by Spark"
         ))
     }
   }
@@ -693,4 +692,569 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
       errorClass = "CANNOT_CONSTRUCT_PROTOBUF_DESCRIPTOR",
       parameters = Map("descFilePath" -> testFileDescriptor))
   }
+
+  test("Verify OneOf field between from_protobuf -> to_protobuf and struct -> 
from_protobuf") {
+    val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "OneOfEvent")
+    val oneOfEvent = OneOfEvent.newBuilder()
+      .setKey("key")
+      .setCol1(123)
+      .setCol3(109202L)
+      .setCol2("col2value")
+      .addCol4("col4value").build()
+
+    val df = Seq(oneOfEvent.toByteArray).toDF("value")
+
+    checkWithFileAndClassName("OneOfEvent") {
+      case (name, descFilePathOpt) =>
+        val fromProtoDf = df.select(
+          from_protobuf_wrapper($"value", name, descFilePathOpt) as 'sample)
+        val toDf = fromProtoDf.select(
+          to_protobuf_wrapper($"sample", name, descFilePathOpt) as 'toProto)
+        val toFromDf = toDf.select(
+          from_protobuf_wrapper($"toProto", name, descFilePathOpt) as 
'fromToProto)
+        checkAnswer(fromProtoDf, toFromDf)
+        val actualFieldNames = 
fromProtoDf.select("sample.*").schema.fields.toSeq.map(f => f.name)
+        descriptor.getFields.asScala.map(f => {
+          assert(actualFieldNames.contains(f.getName))
+        })
+
+        val eventFromSpark = OneOfEvent.parseFrom(
+          toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
+        // OneOf field: the last set value(by order) will overwrite all 
previous ones.
+        assert(eventFromSpark.getCol2.equals("col2value"))
+        assert(eventFromSpark.getCol3 == 0)
+        val expectedFields = descriptor.getFields.asScala.map(f => f.getName)
+        eventFromSpark.getDescriptorForType.getFields.asScala.map(f => {
+          assert(expectedFields.contains(f.getName))
+        })
+
+        val jsonSchema =
+          s"""
+             |{
+             |  "type" : "struct",
+             |  "fields" : [ {
+             |    "name" : "sample",
+             |    "type" : {
+             |      "type" : "struct",
+             |      "fields" : [ {
+             |        "name" : "key",
+             |        "type" : "string",
+             |        "nullable" : true
+             |      }, {
+             |        "name" : "col_1",
+             |        "type" : "integer",
+             |        "nullable" : true
+             |      }, {
+             |        "name" : "col_2",
+             |        "type" : "string",
+             |        "nullable" : true
+             |      }, {
+             |        "name" : "col_3",
+             |        "type" : "long",
+             |        "nullable" : true
+             |      }, {
+             |        "name" : "col_4",
+             |        "type" : {
+             |          "type" : "array",
+             |          "elementType" : "string",
+             |          "containsNull" : false
+             |        },
+             |        "nullable" : false
+             |      } ]
+             |    },
+             |    "nullable" : true
+             |  } ]
+             |}
+             |{
+             |  "type" : "struct",
+             |  "fields" : [ {
+             |    "name" : "sample",
+             |    "type" : {
+             |      "type" : "struct",
+             |      "fields" : [ {
+             |        "name" : "key",
+             |        "type" : "string",
+             |        "nullable" : true
+             |      }, {
+             |        "name" : "col_1",
+             |        "type" : "integer",
+             |        "nullable" : true
+             |      }, {
+             |        "name" : "col_2",
+             |        "type" : "string",
+             |        "nullable" : true
+             |      }, {
+             |        "name" : "col_3",
+             |        "type" : "long",
+             |        "nullable" : true
+             |      }, {
+             |        "name" : "col_4",
+             |        "type" : {
+             |          "type" : "array",
+             |          "elementType" : "string",
+             |          "containsNull" : false
+             |        },
+             |        "nullable" : false
+             |      } ]
+             |    },
+             |    "nullable" : true
+             |  } ]
+             |}
+             |""".stripMargin
+        val schema = DataType.fromJson(jsonSchema).asInstanceOf[StructType]
+        val data = Seq(Row(Row("key", 123, "col2value", 109202L, 
Seq("col4value"))))
+        val dataDf = 
spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+        val dataDfToProto = dataDf.select(
+          to_protobuf_wrapper($"sample", name, descFilePathOpt) as 'toProto)
+
+        val eventFromSparkSchema = OneOfEvent.parseFrom(
+          
dataDfToProto.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
+        assert(eventFromSparkSchema.getCol2.isEmpty)
+        assert(eventFromSparkSchema.getCol3 == 109202L)
+        eventFromSparkSchema.getDescriptorForType.getFields.asScala.map(f => {
+          assert(expectedFields.contains(f.getName))
+        })
+    }
+  }
+
+  test("Fail for recursion field with complex schema without 
recursive.fields.max.depth") {
+    val aEventWithRecursion = EventWithRecursion.newBuilder().setKey(2).build()
+    val aaEventWithRecursion = 
EventWithRecursion.newBuilder().setKey(3).build()
+    val aaaEventWithRecursion = 
EventWithRecursion.newBuilder().setKey(4).build()
+    val c = messageC.newBuilder().setAaa(aaaEventWithRecursion).setKey(12092)
+    val b = messageB.newBuilder().setAa(aaEventWithRecursion).setC(c)
+    val a = messageA.newBuilder().setA(aEventWithRecursion).setB(b).build()
+    val eventWithRecursion = 
EventWithRecursion.newBuilder().setKey(1).setA(a).build()
+
+    val df = Seq(eventWithRecursion.toByteArray).toDF("protoEvent")
+
+    checkWithFileAndClassName("EventWithRecursion") {
+      case (name, descFilePathOpt) =>
+        val e = intercept[AnalysisException] {
+          df.select(
+            from_protobuf_wrapper($"protoEvent", name, 
descFilePathOpt).as("messageFromProto"))
+            .show()
+        }
+        assert(e.getMessage.contains(
+          "Found recursive reference in Protobuf schema, which can not be 
processed by Spark"
+        ))
+    }
+  }
+
+  test("Verify recursion field with complex schema with 
recursive.fields.max.depth") {
+    val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "Employee")
+
+    val manager = 
Employee.newBuilder().setFirstName("firstName").setLastName("lastName").build()
+    val em2 = EM2.newBuilder().setTeamsize(100).setEm2Manager(manager).build()
+    val em = EM.newBuilder().setTeamsize(100).setEmManager(manager).build()
+    val ic = IC.newBuilder().addSkills("java").setIcManager(manager).build()
+    val employee = Employee.newBuilder().setFirstName("firstName")
+      .setLastName("lastName").setEm2(em2).setEm(em).setIc(ic).build()
+
+    val df = Seq(employee.toByteArray).toDF("protoEvent")
+    val options = new java.util.HashMap[String, String]()
+    options.put("recursive.fields.max.depth", "1")
+
+    val fromProtoDf = df.select(
+      functions.from_protobuf($"protoEvent", "Employee", testFileDesc, 
options) as 'sample)
+
+    val toDf = fromProtoDf.select(
+      functions.to_protobuf($"sample", "Employee", testFileDesc) as 'toProto)
+    val toFromDf = toDf.select(
+      functions.from_protobuf($"toProto",
+        "Employee",
+        testFileDesc,
+        options) as 'fromToProto)
+
+    checkAnswer(fromProtoDf, toFromDf)
+
+    val actualFieldNames = 
fromProtoDf.select("sample.*").schema.fields.toSeq.map(f => f.name)
+    descriptor.getFields.asScala.map(f => {
+      assert(actualFieldNames.contains(f.getName))
+    })
+
+    val eventFromSpark = Employee.parseFrom(
+      toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
+
+    assert(eventFromSpark.getIc.getIcManager.getFirstName.equals("firstName"))
+    assert(eventFromSpark.getIc.getIcManager.getLastName.equals("lastName"))
+    assert(eventFromSpark.getEm2.getEm2Manager.getFirstName.isEmpty)
+  }
+
+  test("Verify OneOf field with recursive fields between from_protobuf -> 
to_protobuf." +
+    "and struct -> from_protobuf") {
+    val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, 
"OneOfEventWithRecursion")
+
+    val nestedTwo = OneOfEventWithRecursion.newBuilder()
+      .setKey("keyNested2").setValue("valueNested2").build()
+    val nestedOne = EventRecursiveA.newBuilder()
+      .setKey("keyNested1")
+      .setRecursiveA(nestedTwo).build()
+    val oneOfRecursionEvent = OneOfEventWithRecursion.newBuilder()
+      .setKey("keyNested0")
+      .setValue("valueNested0")
+      .setRecursiveA(nestedOne).build()
+    val recursiveA = EventRecursiveA.newBuilder().setKey("recursiveAKey")
+      .setRecursiveA(oneOfRecursionEvent).build()
+    val recursiveB = EventRecursiveB.newBuilder()
+      .setKey("recursiveBKey")
+      .setValue("recursiveBvalue").build()
+    val oneOfEventWithRecursion = OneOfEventWithRecursion.newBuilder()
+      .setKey("key")
+      .setValue("value")
+      .setRecursiveB(recursiveB)
+      .setRecursiveA(recursiveA).build()
+
+    val df = Seq(oneOfEventWithRecursion.toByteArray).toDF("value")
+
+    val options = new java.util.HashMap[String, String]()
+    options.put("recursive.fields.max.depth", "1")
+
+    val fromProtoDf = df.select(
+      functions.from_protobuf($"value",
+        "OneOfEventWithRecursion",
+        testFileDesc, options) as 'sample)
+    val toDf = fromProtoDf.select(
+      functions.to_protobuf($"sample", "OneOfEventWithRecursion", 
testFileDesc) as 'toProto)
+    val toFromDf = toDf.select(
+      functions.from_protobuf($"toProto",
+        "OneOfEventWithRecursion",
+        testFileDesc,
+        options) as 'fromToProto)
+
+    checkAnswer(fromProtoDf, toFromDf)
+
+    val actualFieldNames = 
fromProtoDf.select("sample.*").schema.fields.toSeq.map(f => f.name)
+    descriptor.getFields.asScala.map(f => {
+      assert(actualFieldNames.contains(f.getName))
+    })
+
+    val eventFromSpark = OneOfEventWithRecursion.parseFrom(
+      toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
+
+    var recursiveField = eventFromSpark.getRecursiveA.getRecursiveA
+    assert(recursiveField.getKey.equals("keyNested0"))
+    assert(recursiveField.getValue.equals("valueNested0"))
+    assert(recursiveField.getRecursiveA.getKey.equals("keyNested1"))
+    assert(recursiveField.getRecursiveA.getRecursiveA.getKey.isEmpty())
+
+    val expectedFields = descriptor.getFields.asScala.map(f => f.getName)
+    eventFromSpark.getDescriptorForType.getFields.asScala.map(f => {
+      assert(expectedFields.contains(f.getName))
+    })
+
+    val jsonSchema =
+      s"""
+         |{
+         |  "type" : "struct",
+         |  "fields" : [ {
+         |    "name" : "sample",
+         |    "type" : {
+         |      "type" : "struct",
+         |      "fields" : [ {
+         |        "name" : "key",
+         |        "type" : "string",
+         |        "nullable" : true
+         |      }, {
+         |        "name" : "recursiveA",
+         |        "type" : {
+         |          "type" : "struct",
+         |          "fields" : [ {
+         |            "name" : "recursiveA",
+         |            "type" : {
+         |              "type" : "struct",
+         |              "fields" : [ {
+         |                "name" : "key",
+         |                "type" : "string",
+         |                "nullable" : true
+         |              }, {
+         |                "name" : "recursiveA",
+         |                "type" : "void",
+         |                "nullable" : true
+         |              }, {
+         |                "name" : "recursiveB",
+         |                "type" : {
+         |                  "type" : "struct",
+         |                  "fields" : [ {
+         |                    "name" : "key",
+         |                    "type" : "string",
+         |                    "nullable" : true
+         |                  }, {
+         |                    "name" : "value",
+         |                    "type" : "string",
+         |                    "nullable" : true
+         |                  }, {
+         |                    "name" : "recursiveA",
+         |                    "type" : {
+         |                      "type" : "struct",
+         |                      "fields" : [ {
+         |                        "name" : "key",
+         |                        "type" : "string",
+         |                        "nullable" : true
+         |                      }, {
+         |                        "name" : "recursiveA",
+         |                        "type" : "void",
+         |                        "nullable" : true
+         |                      }, {
+         |                        "name" : "recursiveB",
+         |                        "type" : "void",
+         |                        "nullable" : true
+         |                      }, {
+         |                        "name" : "value",
+         |                        "type" : "string",
+         |                        "nullable" : true
+         |                      } ]
+         |                    },
+         |                    "nullable" : true
+         |                  } ]
+         |                },
+         |                "nullable" : true
+         |              }, {
+         |                "name" : "value",
+         |                "type" : "string",
+         |                "nullable" : true
+         |              } ]
+         |            },
+         |            "nullable" : true
+         |          }, {
+         |            "name" : "key",
+         |            "type" : "string",
+         |            "nullable" : true
+         |          } ]
+         |        },
+         |        "nullable" : true
+         |      }, {
+         |        "name" : "recursiveB",
+         |        "type" : {
+         |          "type" : "struct",
+         |          "fields" : [ {
+         |            "name" : "key",
+         |            "type" : "string",
+         |            "nullable" : true
+         |          }, {
+         |            "name" : "value",
+         |            "type" : "string",
+         |            "nullable" : true
+         |          }, {
+         |            "name" : "recursiveA",
+         |            "type" : {
+         |              "type" : "struct",
+         |              "fields" : [ {
+         |                "name" : "key",
+         |                "type" : "string",
+         |                "nullable" : true
+         |              }, {
+         |                "name" : "recursiveA",
+         |                "type" : {
+         |                  "type" : "struct",
+         |                  "fields" : [ {
+         |                    "name" : "recursiveA",
+         |                    "type" : {
+         |                      "type" : "struct",
+         |                      "fields" : [ {
+         |                        "name" : "key",
+         |                        "type" : "string",
+         |                        "nullable" : true
+         |                      }, {
+         |                        "name" : "recursiveA",
+         |                        "type" : "void",
+         |                        "nullable" : true
+         |                      }, {
+         |                        "name" : "recursiveB",
+         |                        "type" : "void",
+         |                        "nullable" : true
+         |                      }, {
+         |                        "name" : "value",
+         |                        "type" : "string",
+         |                        "nullable" : true
+         |                      } ]
+         |                    },
+         |                    "nullable" : true
+         |                  }, {
+         |                    "name" : "key",
+         |                    "type" : "string",
+         |                    "nullable" : true
+         |                  } ]
+         |                },
+         |                "nullable" : true
+         |              }, {
+         |                "name" : "recursiveB",
+         |                "type" : "void",
+         |                "nullable" : true
+         |              }, {
+         |                "name" : "value",
+         |                "type" : "string",
+         |                "nullable" : true
+         |              } ]
+         |            },
+         |            "nullable" : true
+         |          } ]
+         |        },
+         |        "nullable" : true
+         |      }, {
+         |        "name" : "value",
+         |        "type" : "string",
+         |        "nullable" : true
+         |      } ]
+         |    },
+         |    "nullable" : true
+         |  } ]
+         |}
+         |""".stripMargin
+
+    val schema = DataType.fromJson(jsonSchema).asInstanceOf[StructType]
+    val data = Seq(
+      Row(
+        Row("key1",
+          Row(
+            Row("keyNested0", null, null, "valueNested0"),
+            "recursiveAKey"),
+          null,
+          "value1")
+      )
+    )
+    val dataDf = spark.createDataFrame(spark.sparkContext.parallelize(data), 
schema)
+    val dataDfToProto = dataDf.select(
+      functions.to_protobuf($"sample", "OneOfEventWithRecursion", 
testFileDesc) as 'toProto)
+
+    val eventFromSparkSchema = OneOfEventWithRecursion.parseFrom(
+      dataDfToProto.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
+    recursiveField = eventFromSparkSchema.getRecursiveA.getRecursiveA
+    assert(recursiveField.getKey.equals("keyNested0"))
+    assert(recursiveField.getValue.equals("valueNested0"))
+    assert(recursiveField.getRecursiveA.getKey.isEmpty())
+    eventFromSparkSchema.getDescriptorForType.getFields.asScala.map(f => {
+      assert(expectedFields.contains(f.getName))
+    })
+  }
+
+  test("Verify recursive.fields.max.depth Levels 0,1, and 2 with Simple 
Schema") {
+    val eventPerson3 = EventPerson.newBuilder().setName("person3").build()
+    val eventPerson2 = 
EventPerson.newBuilder().setName("person2").setBff(eventPerson3).build()
+    val eventPerson1 = 
EventPerson.newBuilder().setName("person1").setBff(eventPerson2).build()
+    val eventPerson0 = 
EventPerson.newBuilder().setName("person0").setBff(eventPerson1).build()
+    val df = Seq(eventPerson0.toByteArray).toDF("value")
+
+    val optionsZero = new java.util.HashMap[String, String]()
+    optionsZero.put("recursive.fields.max.depth", "0")
+    val schemaZero = DataType.fromJson(
+        s"""{
+           |  "type" : "struct",
+           |  "fields" : [ {
+           |    "name" : "sample",
+           |    "type" : {
+           |      "type" : "struct",
+           |      "fields" : [ {
+           |        "name" : "name",
+           |        "type" : "string",
+           |        "nullable" : true
+           |      }, {
+           |        "name" : "bff",
+           |        "type" : "void",
+           |        "nullable" : true
+           |      } ]
+           |    },
+           |    "nullable" : true
+           |  } ]
+           |}""".stripMargin).asInstanceOf[StructType]
+    val expectedDfZero = spark.createDataFrame(
+      spark.sparkContext.parallelize(Seq(Row(Row("person0", null)))), 
schemaZero)
+
+    testFromProtobufWithOptions(df, expectedDfZero, optionsZero)
+
+    val optionsOne = new java.util.HashMap[String, String]()
+    optionsOne.put("recursive.fields.max.depth", "1")
+    val schemaOne = DataType.fromJson(
+      s"""{
+         |  "type" : "struct",
+         |  "fields" : [ {
+         |    "name" : "sample",
+         |    "type" : {
+         |      "type" : "struct",
+         |      "fields" : [ {
+         |        "name" : "name",
+         |        "type" : "string",
+         |        "nullable" : true
+         |      }, {
+         |        "name" : "bff",
+         |        "type" : {
+         |          "type" : "struct",
+         |          "fields" : [ {
+         |            "name" : "name",
+         |            "type" : "string",
+         |            "nullable" : true
+         |          }, {
+         |            "name" : "bff",
+         |            "type" : "void",
+         |            "nullable" : true
+         |          } ]
+         |        },
+         |        "nullable" : true
+         |      } ]
+         |    },
+         |    "nullable" : true
+         |  } ]
+         |}""".stripMargin).asInstanceOf[StructType]
+    val expectedDfOne = spark.createDataFrame(
+      spark.sparkContext.parallelize(Seq(Row(Row("person0", Row("person1", 
null))))), schemaOne)
+    testFromProtobufWithOptions(df, expectedDfOne, optionsOne)
+
+    val optionsTwo = new java.util.HashMap[String, String]()
+    optionsTwo.put("recursive.fields.max.depth", "2")
+    val schemaTwo = DataType.fromJson(
+      s"""{
+         |  "type" : "struct",
+         |  "fields" : [ {
+         |    "name" : "sample",
+         |    "type" : {
+         |      "type" : "struct",
+         |      "fields" : [ {
+         |        "name" : "name",
+         |        "type" : "string",
+         |        "nullable" : true
+         |      }, {
+         |        "name" : "bff",
+         |        "type" : {
+         |          "type" : "struct",
+         |          "fields" : [ {
+         |            "name" : "name",
+         |            "type" : "string",
+         |            "nullable" : true
+         |          }, {
+         |            "name" : "bff",
+         |            "type" : {
+         |              "type" : "struct",
+         |              "fields" : [ {
+         |                "name" : "name",
+         |                "type" : "string",
+         |                "nullable" : true
+         |              }, {
+         |                "name" : "bff",
+         |                "type" : "void",
+         |                "nullable" : true
+         |              } ]
+         |            },
+         |            "nullable" : true
+         |          } ]
+         |        },
+         |        "nullable" : true
+         |      } ]
+         |    },
+         |    "nullable" : true
+         |  } ]
+         |}""".stripMargin).asInstanceOf[StructType]
+    val expectedDfTwo = spark.createDataFrame(spark.sparkContext.parallelize(
+      Seq(Row(Row("person0", Row("person1", Row("person2", null)))))), 
schemaTwo)
+    testFromProtobufWithOptions(df, expectedDfTwo, optionsTwo)
+  }
+
+  def testFromProtobufWithOptions(
+    df: DataFrame,
+    expectedDf: DataFrame,
+    options: java.util.HashMap[String, String]): Unit = {
+    val fromProtoDf = df.select(
+      functions.from_protobuf($"value", "EventPerson", testFileDesc, options) 
as 'sample)
+    assert(expectedDf.schema === fromProtoDf.schema)
+    checkAnswer(fromProtoDf, expectedDf)
+  }
 }
diff --git a/core/src/main/resources/error/error-classes.json 
b/core/src/main/resources/error/error-classes.json
index b5e846a8a89..f176726d0ce 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1048,7 +1048,7 @@
   },
   "RECURSIVE_PROTOBUF_SCHEMA" : {
     "message" : [
-      "Found recursive reference in Protobuf schema, which can not be 
processed by Spark: <fieldDescriptor>"
+      "Found recursive reference in Protobuf schema, which can not be 
processed by Spark by default: <fieldDescriptor>. try setting the option 
`recursive.fields.max.depth` 0 to 10. Going beyond 10 levels of recursion is 
not allowed."
     ]
   },
   "RENAME_SRC_PATH_NOT_FOUND" : {



