This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 72b40677a27 [SPARK-42406] Terminate Protobuf recursive fields by 
dropping the field
72b40677a27 is described below

commit 72b40677a27b3e86fb7f98d46ccd86d650e4f2db
Author: Raghu Angadi <raghu.ang...@databricks.com>
AuthorDate: Mon Feb 27 20:44:17 2023 -0800

    [SPARK-42406] Terminate Protobuf recursive fields by dropping the field
    
    ### What changes were proposed in this pull request?
    
    Protobuf deserializer (`from_protobuf()` function) optionally supports
recursive fields up to a certain depth. Currently it uses `NullType` to
terminate the recursion. But an `ArrayType` containing `NullType` is not really
useful, and it does not work with Delta.
    
    This PR fixes this by removing the field to terminate recursion rather than 
using `NullType`.
    The following example illustrates the difference.
    
    E.g., consider a recursive Protobuf message like this:
    ```
    message Node {
        int32 value = 1;
        repeated Node children = 2;  // recursive array
    }
    message Tree {
        Node root = 1;
    }
    ```
    The Catalyst schema from `from_protobuf()` for `Tree`, with max recursive
depth set to 2, would be:
    
       - **Before**:  _STRUCT<root: STRUCT<value: int, children: array<STRUCT<value: int, **children: array<void>**>>>>_
       - **After**: _STRUCT<root: STRUCT<value: int, children: array<STRUCT<value: int>>>>_
    
    Notice that at the second level, the `children` array is dropped rather
than being defined as `array<void>`.
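    
    As an illustration, a minimal sketch of reading such messages with the
`recursive.fields.max.depth` option (the DataFrame `df`, its binary `value`
column, and the descriptor file path are illustrative, not part of this patch):
    
    ```scala
    import scala.collection.JavaConverters._
    
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.protobuf.functions.from_protobuf
    
    // Hypothetical setup: `df` holds serialized `Tree` messages in a binary
    // column "value"; "/tmp/tree.desc" is a descriptor set produced by protoc.
    val options = Map("recursive.fields.max.depth" -> "2").asJava
    val parsed = df.select(
      from_protobuf(col("value"), "Tree", "/tmp/tree.desc", options).as("tree"))
    // With this change, the second-level `children` field is dropped entirely:
    // tree: struct<root: struct<value: int, children: array<struct<value: int>>>>
    parsed.printSchema()
    ```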
    
    ### Why are the changes needed?
     - This improves how the Protobuf connector handles recursive fields. It
avoids using `void` fields, which are problematic in many scenarios and do not
add any information.
    
    ### Does this PR introduce _any_ user-facing change?
     - This changes the schema in a subtle manner when recursive support is
enabled. Since it only removes an optional field, it is backward compatible.
    
    ### How was this patch tested?
     - Added multiple unit tests and updated existing ones. Most of the changes
in this PR are in the tests.
    
    Closes #40141 from rangadi/recursive-fields.
    
    Authored-by: Raghu Angadi <raghu.ang...@databricks.com>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
    (cherry picked from commit e397a1585d3b089f155ccb4359d5525cd012d5da)
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../spark/sql/protobuf/ProtobufDeserializer.scala  |   3 -
 .../sql/protobuf/utils/SchemaConverters.scala      |  84 ++-
 .../test/resources/protobuf/functions_suite.desc   | Bin 8836 -> 9648 bytes
 .../test/resources/protobuf/functions_suite.proto  |  29 +-
 .../sql/protobuf/ProtobufFunctionsSuite.scala      | 619 ++++++++-------------
 5 files changed, 301 insertions(+), 434 deletions(-)

diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
index 37278fab8a3..7723687a4d9 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
@@ -157,9 +157,6 @@ private[sql] class ProtobufDeserializer(
     (protoType.getJavaType, catalystType) match {
 
       case (null, NullType) => (updater, ordinal, _) => 
updater.setNullAt(ordinal)
-      // It is possible that this will result in data being dropped, This is 
intentional,
-      // to catch recursive fields and drop them as necessary.
-      case (MESSAGE, NullType) => (updater, ordinal, _) => 
updater.setNullAt(ordinal)
 
       // TODO: we can avoid boxing if future version of Protobuf provide 
primitive accessors.
       case (BOOLEAN, BooleanType) =>
diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
index 9666e34bab4..e277f2999e4 100644
--- 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
+++ 
b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -21,11 +21,12 @@ import scala.collection.JavaConverters._
 import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
 
 @DeveloperApi
-object SchemaConverters {
+object SchemaConverters extends Logging {
 
   /**
    * Internal wrapper for SQL data type and nullability.
@@ -59,6 +60,8 @@ object SchemaConverters {
   // existingRecordNames: Map[String, Int] used to track the depth of 
recursive fields and to
   // ensure that the conversion of the protobuf message to a Spark SQL 
StructType object does not
   // exceed the maximum recursive depth specified by the 
recursiveFieldMaxDepth option.
+  // A return of None implies the field has reached the maximum allowed 
recursive depth and
+  // should be dropped.
   def structFieldFor(
       fd: FieldDescriptor,
       existingRecordNames: Map[String, Int],
@@ -84,10 +87,10 @@ object SchemaConverters {
           fd.getMessageType.getFields.size() == 2 &&
           fd.getMessageType.getFields.get(0).getName.equals("seconds") &&
           fd.getMessageType.getFields.get(1).getName.equals("nanos")) =>
-          Some(TimestampType)
+        Some(TimestampType)
       case MESSAGE if fd.isRepeated && 
fd.getMessageType.getOptions.hasMapEntry =>
-        var keyType: DataType = NullType
-        var valueType: DataType = NullType
+        var keyType: Option[DataType] = None
+        var valueType: Option[DataType] = None
         fd.getMessageType.getFields.forEach { field =>
           field.getName match {
             case "key" =>
@@ -95,32 +98,42 @@ object SchemaConverters {
                 structFieldFor(
                   field,
                   existingRecordNames,
-                  protobufOptions).get.dataType
+                  protobufOptions).map(_.dataType)
             case "value" =>
               valueType =
                 structFieldFor(
                   field,
                   existingRecordNames,
-                  protobufOptions).get.dataType
+                  protobufOptions).map(_.dataType)
           }
         }
-        return Option(
-          StructField(
-            fd.getName,
-            MapType(keyType, valueType, valueContainsNull = 
false).defaultConcreteType,
-            nullable = false))
+        (keyType, valueType) match {
+          case (None, _) =>
+            // This is probably never expected. Protobuf does not allow 
complex types for keys.
+            log.info(s"Dropping map field ${fd.getFullName}. Key reached max 
recursive depth.")
+            None
+          case (_, None) =>
+            log.info(s"Dropping map field ${fd.getFullName}. Value reached max 
recursive depth.")
+            None
+          case (Some(kt), Some(vt)) => Some(MapType(kt, vt, valueContainsNull 
= false))
+        }
       case MESSAGE =>
-        // If the `recursive.fields.max.depth` value is not specified, it will 
default to -1;
-        // recursive fields are not permitted. Setting it to 0 drops all 
recursive fields,
+        // If the `recursive.fields.max.depth` value is not specified, it will 
default to -1,
+        // and recursive fields are not permitted. Setting it to 0 drops all 
recursive fields,
         // 1 allows it to be recursed once, and 2 allows it to be recursed 
twice and so on.
         // A value greater than 10 is not allowed, and if a protobuf record 
has more depth for
         // recursive fields than the allowed value, it will be truncated and 
some fields may be
         // discarded.
-        // SQL Schema for the protobuf message `message Person { string name = 
1; Person bff = 2}`
+        // SQL Schema for protobuf `message Person { string name = 1; Person bff = 2;}`
         // will vary based on the value of "recursive.fields.max.depth".
-        // 1: struct<name: string, bff: null>
-        // 2: struct<name string, bff: <name: string, bff: null>>
-        // 3: struct<name string, bff: <name: string, bff: struct<name: 
string, bff: null>>> ...
+        // 1: struct<name: string>
+        // 2: struct<name string, bff: struct<name: string>>
+        // 3: struct<name string, bff: struct<name string, bff: struct<name: 
string>>>
+        // and so on.
+        // TODO(rangadi): A better way to terminate would be to replace the remaining recursive struct
+        //      with the byte array of corresponding protobuf. This way no 
information is lost.
+        //      i.e. with max depth 2, the above looks like this:
+        //      struct<name: string, bff: struct<name: string, 
_serialized_bff: bytes>>
         val recordName = fd.getMessageType.getFullName
         val recursiveDepth = existingRecordNames.getOrElse(recordName, 0)
         val recursiveFieldMaxDepth = protobufOptions.recursiveFieldMaxDepth
@@ -129,23 +142,36 @@ object SchemaConverters {
           throw 
QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
         } else if (existingRecordNames.contains(recordName) &&
           recursiveDepth >= recursiveFieldMaxDepth) {
-          Some(NullType)
+          // Recursive depth limit is reached. This field is dropped.
+          // If it is inside a container like map or array, the containing 
field is dropped.
+          log.info(
+            s"The field ${fd.getFullName} of type $recordName is dropped " +
+              s"at recursive depth $recursiveDepth"
+          )
+          None
         } else {
           val newRecordNames = existingRecordNames + (recordName -> 
(recursiveDepth + 1))
-          Option(
-            fd.getMessageType.getFields.asScala
-              .flatMap(structFieldFor(_, newRecordNames, protobufOptions))
-              .toSeq)
-            .filter(_.nonEmpty)
-            .map(StructType.apply)
+          val fields = fd.getMessageType.getFields.asScala.flatMap(
+            structFieldFor(_, newRecordNames, protobufOptions)
+          ).toSeq
+          fields match {
+            case Nil =>
+              log.info(
+                s"Dropping ${fd.getFullName} as it does not have any fields 
left " +
+                "likely due to recursive depth limit."
+              )
+              None
+            case fds => Some(StructType(fds))
+          }
         }
       case other =>
         throw 
QueryCompilationErrors.protobufTypeUnsupportedYetError(other.toString)
     }
-    dataType.map(dt =>
-      StructField(
-        fd.getName,
-        if (fd.isRepeated) ArrayType(dt, containsNull = false) else dt,
-        nullable = !fd.isRequired && !fd.isRepeated))
+    dataType.map {
+      case dt: MapType => StructField(fd.getName, dt)
+      case dt if fd.isRepeated =>
+        StructField(fd.getName, ArrayType(dt, containsNull = false))
+      case dt => StructField(fd.getName, dt, nullable = !fd.isRequired)
+    }
   }
 }
diff --git 
a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc 
b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc
index d16f8935080..467b9cac969 100644
Binary files 
a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc and 
b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc differ
diff --git 
a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto 
b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
index a0698ee3979..d83ba6a4f6e 100644
--- a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
+++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
@@ -233,6 +233,19 @@ message EventPersonWrapper {
   EventPerson person = 1;
 }
 
+message PersonWithRecursiveArray {
+  // A protobuf with recursive repeated field
+  string name = 1;
+  repeated PersonWithRecursiveArray friends = 2;
+}
+
+message PersonWithRecursiveMap {
+  // A protobuf with recursive field in value
+  string name = 1;
+  map<string, PersonWithRecursiveMap> groups = 3;
+}
+
+
 message OneOfEventWithRecursion {
   string key = 1;
   oneof payload {
@@ -243,14 +256,26 @@ message OneOfEventWithRecursion {
 }
 
 message EventRecursiveA {
-  OneOfEventWithRecursion recursiveA = 1;
+  OneOfEventWithRecursion recursiveOneOffInA = 1;
   string key = 2;
 }
 
 message EventRecursiveB {
   string key = 1;
   string value = 2;
-  OneOfEventWithRecursion recursiveA = 3;
+  OneOfEventWithRecursion recursiveOneOffInB = 3;
+}
+
+message EmptyRecursiveProto {
+  // This is a recursive proto with no fields. Used to test an edge case. Catalyst schema
+  // for this should be "nothing" (i.e. completely dropped) irrespective of the recursive limit.
+  EmptyRecursiveProto recursive_field = 1;
+  repeated EmptyRecursiveProto recursive_array = 2;
+}
+
+message EmptyRecursiveProtoWrapper {
+  string name = 1;
+  EmptyRecursiveProto empty_recursive = 2; // This field will be dropped.
 }
 
 message Status {
diff --git 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 60e13644fc6..92c3c27bfae 100644
--- 
a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -19,17 +19,17 @@ package org.apache.spark.sql.protobuf
 import java.sql.Timestamp
 import java.time.Duration
 
 import scala.collection.JavaConverters._
 
 import com.google.protobuf.{ByteString, DynamicMessage}
 
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, 
Row}
 import org.apache.spark.sql.functions.{lit, struct}
-import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{EM, EM2, 
Employee, EventPerson, EventPersonWrapper, EventRecursiveA, EventRecursiveB, 
IC, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated}
+import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos._
 import 
org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum
 import org.apache.spark.sql.protobuf.utils.ProtobufUtils
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, IntegerType, 
StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.types._
 
 class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with 
ProtobufTestBase
   with Serializable {
@@ -56,10 +56,15 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
 
   // A wrapper to invoke the right variable of from_protobuf() depending on 
arguments.
   private def from_protobuf_wrapper(
-    col: Column, messageName: String, descFilePathOpt: Option[String]): Column 
= {
+    col: Column,
+    messageName: String,
+    descFilePathOpt: Option[String],
+    options: Map[String, String] = Map.empty): Column = {
     descFilePathOpt match {
-      case Some(descFilePath) => functions.from_protobuf(col, messageName, 
descFilePath)
-      case None => functions.from_protobuf(col, messageName)
+      case Some(descFilePath) => functions.from_protobuf(
+        col, messageName, descFilePath, options.asJava
+      )
+      case None => functions.from_protobuf(col, messageName, options.asJava)
     }
   }
 
@@ -72,7 +77,6 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
     }
   }
 
-
   test("roundtrip in to_protobuf and from_protobuf - struct") {
     val df = spark
       .range(1, 10)
@@ -352,7 +356,7 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
     }
   }
 
-  test("roundtrip in from_protobuf and to_protobuf - Multiple Message") {
+  test("round trip in from_protobuf and to_protobuf - Multiple Message") {
     val messageMultiDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"MultipleExample")
     val messageIncludeDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"IncludedExample")
     val messageOtherDesc = ProtobufUtils.buildDescriptor(testFileDesc, 
"OtherExample")
@@ -385,10 +389,9 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
           from_protobuf_wrapper($"value_to", name, 
descFilePathOpt).as("value_to_from"))
         checkAnswer(fromProtoDF.select($"value_from.*"), 
toFromProtoDF.select($"value_to_from.*"))
     }
-  }
 
-  test("Recursive fields in Protobuf should result in an error (B -> A -> B)") 
{
-    checkWithFileAndClassName("recursiveB") {
+    // Simple recursion
+    checkWithFileAndClassName("recursiveB") { // B -> A -> B
       case (name, descFilePathOpt) =>
         val e = intercept[AnalysisException] {
           emptyBinaryDF.select(
@@ -702,92 +705,44 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
           assert(expectedFields.contains(f.getName))
         })
 
-        val jsonSchema =
-          s"""
-             |{
-             |  "type" : "struct",
-             |  "fields" : [ {
-             |    "name" : "sample",
-             |    "type" : {
-             |      "type" : "struct",
-             |      "fields" : [ {
-             |        "name" : "key",
-             |        "type" : "string",
-             |        "nullable" : true
-             |      }, {
-             |        "name" : "col_1",
-             |        "type" : "integer",
-             |        "nullable" : true
-             |      }, {
-             |        "name" : "col_2",
-             |        "type" : "string",
-             |        "nullable" : true
-             |      }, {
-             |        "name" : "col_3",
-             |        "type" : "long",
-             |        "nullable" : true
-             |      }, {
-             |        "name" : "col_4",
-             |        "type" : {
-             |          "type" : "array",
-             |          "elementType" : "string",
-             |          "containsNull" : false
-             |        },
-             |        "nullable" : false
-             |      } ]
-             |    },
-             |    "nullable" : true
-             |  } ]
-             |}
-             |{
-             |  "type" : "struct",
-             |  "fields" : [ {
-             |    "name" : "sample",
-             |    "type" : {
-             |      "type" : "struct",
-             |      "fields" : [ {
-             |        "name" : "key",
-             |        "type" : "string",
-             |        "nullable" : true
-             |      }, {
-             |        "name" : "col_1",
-             |        "type" : "integer",
-             |        "nullable" : true
-             |      }, {
-             |        "name" : "col_2",
-             |        "type" : "string",
-             |        "nullable" : true
-             |      }, {
-             |        "name" : "col_3",
-             |        "type" : "long",
-             |        "nullable" : true
-             |      }, {
-             |        "name" : "col_4",
-             |        "type" : {
-             |          "type" : "array",
-             |          "elementType" : "string",
-             |          "containsNull" : false
-             |        },
-             |        "nullable" : false
-             |      } ]
-             |    },
-             |    "nullable" : true
-             |  } ]
-             |}
-             |""".stripMargin
-        val schema = DataType.fromJson(jsonSchema).asInstanceOf[StructType]
-        val data = Seq(Row(Row("key", 123, "col2value", 109202L, 
Seq("col4value"))))
+        val schema = DataType.fromJson(
+          """
+            | {
+            |   "type":"struct",
+            |   "fields":[
+            |     {"name":"sample","nullable":true,"type":{
+            |       "type":"struct",
+            |       "fields":[
+            |         {"name":"key","type":"string","nullable":true},
+            |         {"name":"col_1","type":"integer","nullable":true},
+            |         {"name":"col_2","type":"string","nullable":true},
+            |         {"name":"col_3","type":"long","nullable":true},
+            |         {"name":"col_4","nullable":true,"type":{
+            |           
"type":"array","elementType":"string","containsNull":false}}
+            |       ]}
+            |     }
+            |   ]
+            | }
+            |""".stripMargin).asInstanceOf[StructType]
+        assert(fromProtoDf.schema == schema)
+
+        val data = Seq(
+          Row(Row("key", 123, "col2value", 109202L, Seq("col4value"))),
+          Row(Row("key2", null, null, null, null)) // Leave the rest null, 
including "col_4" array.
+        )
         val dataDf = 
spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
         val dataDfToProto = dataDf.select(
           to_protobuf_wrapper($"sample", name, descFilePathOpt) as 'toProto)
 
-        val eventFromSparkSchema = OneOfEvent.parseFrom(
-          
dataDfToProto.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
+        val toProtoResults = dataDfToProto.select("toProto").collect()
+        val eventFromSparkSchema = 
OneOfEvent.parseFrom(toProtoResults(0).getAs[Array[Byte]](0))
         assert(eventFromSparkSchema.getCol2.isEmpty)
         assert(eventFromSparkSchema.getCol3 == 109202L)
         eventFromSparkSchema.getDescriptorForType.getFields.asScala.map(f => {
           assert(expectedFields.contains(f.getName))
         })
+        val secondEventFromSpark = 
OneOfEvent.parseFrom(toProtoResults(1).getAs[Array[Byte]](0))
+        assert(secondEventFromSpark.getKey == "key2")
     }
   }
 
@@ -853,13 +808,13 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
       .setKey("keyNested2").setValue("valueNested2").build()
     val nestedOne = EventRecursiveA.newBuilder()
       .setKey("keyNested1")
-      .setRecursiveA(nestedTwo).build()
+      .setRecursiveOneOffInA(nestedTwo).build()
     val oneOfRecursionEvent = OneOfEventWithRecursion.newBuilder()
       .setKey("keyNested0")
       .setValue("valueNested0")
       .setRecursiveA(nestedOne).build()
     val recursiveA = EventRecursiveA.newBuilder().setKey("recursiveAKey")
-      .setRecursiveA(oneOfRecursionEvent).build()
+      .setRecursiveOneOffInA(oneOfRecursionEvent).build()
     val recursiveB = EventRecursiveB.newBuilder()
       .setKey("recursiveBKey")
       .setValue("recursiveBvalue").build()
@@ -872,7 +827,7 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
     val df = Seq(oneOfEventWithRecursion.toByteArray).toDF("value")
 
     val options = new java.util.HashMap[String, String]()
-    options.put("recursive.fields.max.depth", "2")
+    options.put("recursive.fields.max.depth", "2") // Recursive fields appear 
twice.
 
     val fromProtoDf = df.select(
       functions.from_protobuf($"value",
@@ -896,177 +851,60 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
     val eventFromSpark = OneOfEventWithRecursion.parseFrom(
       toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
 
-    var recursiveField = eventFromSpark.getRecursiveA.getRecursiveA
+    var recursiveField = eventFromSpark.getRecursiveA.getRecursiveOneOffInA
     assert(recursiveField.getKey.equals("keyNested0"))
     assert(recursiveField.getValue.equals("valueNested0"))
     assert(recursiveField.getRecursiveA.getKey.equals("keyNested1"))
-    assert(recursiveField.getRecursiveA.getRecursiveA.getKey.isEmpty())
+    assert(recursiveField.getRecursiveA.getRecursiveOneOffInA.getKey.isEmpty())
 
     val expectedFields = descriptor.getFields.asScala.map(f => f.getName)
     eventFromSpark.getDescriptorForType.getFields.asScala.map(f => {
       assert(expectedFields.contains(f.getName))
     })
 
-    val jsonSchema =
-      s"""
-         |{
-         |  "type" : "struct",
-         |  "fields" : [ {
-         |    "name" : "sample",
-         |    "type" : {
-         |      "type" : "struct",
-         |      "fields" : [ {
-         |        "name" : "key",
-         |        "type" : "string",
-         |        "nullable" : true
-         |      }, {
-         |        "name" : "recursiveA",
-         |        "type" : {
-         |          "type" : "struct",
-         |          "fields" : [ {
-         |            "name" : "recursiveA",
-         |            "type" : {
-         |              "type" : "struct",
-         |              "fields" : [ {
-         |                "name" : "key",
-         |                "type" : "string",
-         |                "nullable" : true
-         |              }, {
-         |                "name" : "recursiveA",
-         |                "type" : "void",
-         |                "nullable" : true
-         |              }, {
-         |                "name" : "recursiveB",
-         |                "type" : {
-         |                  "type" : "struct",
-         |                  "fields" : [ {
-         |                    "name" : "key",
-         |                    "type" : "string",
-         |                    "nullable" : true
-         |                  }, {
-         |                    "name" : "value",
-         |                    "type" : "string",
-         |                    "nullable" : true
-         |                  }, {
-         |                    "name" : "recursiveA",
-         |                    "type" : {
-         |                      "type" : "struct",
-         |                      "fields" : [ {
-         |                        "name" : "key",
-         |                        "type" : "string",
-         |                        "nullable" : true
-         |                      }, {
-         |                        "name" : "recursiveA",
-         |                        "type" : "void",
-         |                        "nullable" : true
-         |                      }, {
-         |                        "name" : "recursiveB",
-         |                        "type" : "void",
-         |                        "nullable" : true
-         |                      }, {
-         |                        "name" : "value",
-         |                        "type" : "string",
-         |                        "nullable" : true
-         |                      } ]
-         |                    },
-         |                    "nullable" : true
-         |                  } ]
-         |                },
-         |                "nullable" : true
-         |              }, {
-         |                "name" : "value",
-         |                "type" : "string",
-         |                "nullable" : true
-         |              } ]
-         |            },
-         |            "nullable" : true
-         |          }, {
-         |            "name" : "key",
-         |            "type" : "string",
-         |            "nullable" : true
-         |          } ]
-         |        },
-         |        "nullable" : true
-         |      }, {
-         |        "name" : "recursiveB",
-         |        "type" : {
-         |          "type" : "struct",
-         |          "fields" : [ {
-         |            "name" : "key",
-         |            "type" : "string",
-         |            "nullable" : true
-         |          }, {
-         |            "name" : "value",
-         |            "type" : "string",
-         |            "nullable" : true
-         |          }, {
-         |            "name" : "recursiveA",
-         |            "type" : {
-         |              "type" : "struct",
-         |              "fields" : [ {
-         |                "name" : "key",
-         |                "type" : "string",
-         |                "nullable" : true
-         |              }, {
-         |                "name" : "recursiveA",
-         |                "type" : {
-         |                  "type" : "struct",
-         |                  "fields" : [ {
-         |                    "name" : "recursiveA",
-         |                    "type" : {
-         |                      "type" : "struct",
-         |                      "fields" : [ {
-         |                        "name" : "key",
-         |                        "type" : "string",
-         |                        "nullable" : true
-         |                      }, {
-         |                        "name" : "recursiveA",
-         |                        "type" : "void",
-         |                        "nullable" : true
-         |                      }, {
-         |                        "name" : "recursiveB",
-         |                        "type" : "void",
-         |                        "nullable" : true
-         |                      }, {
-         |                        "name" : "value",
-         |                        "type" : "string",
-         |                        "nullable" : true
-         |                      } ]
-         |                    },
-         |                    "nullable" : true
-         |                  }, {
-         |                    "name" : "key",
-         |                    "type" : "string",
-         |                    "nullable" : true
-         |                  } ]
-         |                },
-         |                "nullable" : true
-         |              }, {
-         |                "name" : "recursiveB",
-         |                "type" : "void",
-         |                "nullable" : true
-         |              }, {
-         |                "name" : "value",
-         |                "type" : "string",
-         |                "nullable" : true
-         |              } ]
-         |            },
-         |            "nullable" : true
-         |          } ]
-         |        },
-         |        "nullable" : true
-         |      }, {
-         |        "name" : "value",
-         |        "type" : "string",
-         |        "nullable" : true
-         |      } ]
-         |    },
-         |    "nullable" : true
-         |  } ]
-         |}
-         |""".stripMargin
-
-    val schema = DataType.fromJson(jsonSchema).asInstanceOf[StructType]
+    val schemaDDL =
+      """
+        | -- OneOfEventWithRecursion with max depth 2.
+        | sample STRUCT< -- 1st level for OneOfWithRecursion
+        |     key string,
+        |     recursiveA STRUCT< -- 1st level for RecursiveA
+        |         recursiveOneOffInA STRUCT< -- 2nd level for OneOfWithRecursion
+        |             key string,
+        |             recursiveA STRUCT< -- 2nd level for RecursiveA
+        |                 key string
+        |                 -- Removed recursiveOneOffInA: 3rd level for OneOfWithRecursion
+        |             >,
+        |             recursiveB STRUCT<
+        |                 key string,
+        |                 value string
+        |                 -- Removed recursiveOneOffInB: 3rd level for OneOfWithRecursion
+        |             >,
+        |             value string
+        |         >,
+        |         key string
+        |     >,
+        |     recursiveB STRUCT< -- 1st level for RecursiveB
+        |         key string,
+        |         value string,
+        |         recursiveOneOffInB STRUCT< -- 2nd level for OneOfWithRecursion
+        |             key string,
+        |             recursiveA STRUCT< -- 1st level for RecursiveA
+        |                 key string
+        |                 -- Removed recursiveOneOffInA: 3rd level for OneOfWithRecursion
+        |             >,
+        |             recursiveB STRUCT<
+        |                 key string,
+        |                 value string
+        |                 -- Removed recursiveOneOffInB: 3rd level for OneOfWithRecursion
+        |             >,
+        |             value string
+        |         >
+        |     >,
+        |     value string
+        | >
+        |""".stripMargin
+    val schema = structFromDDL(schemaDDL)
+    assert(fromProtoDf.schema == schema)
     val data = Seq(
       Row(
         Row("key1",
@@ -1083,7 +921,7 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
 
     val eventFromSparkSchema = OneOfEventWithRecursion.parseFrom(
       dataDfToProto.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
-    recursiveField = eventFromSparkSchema.getRecursiveA.getRecursiveA
+    recursiveField = eventFromSparkSchema.getRecursiveA.getRecursiveOneOffInA
     assert(recursiveField.getKey.equals("keyNested0"))
     assert(recursiveField.getValue.equals("valueNested0"))
     assert(recursiveField.getRecursiveA.getKey.isEmpty())
@@ -1101,166 +939,61 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
 
     val optionsZero = new java.util.HashMap[String, String]()
     optionsZero.put("recursive.fields.max.depth", "1")
-    val schemaZero = DataType.fromJson(
-        s"""{
-           |  "type" : "struct",
-           |  "fields" : [ {
-           |    "name" : "sample",
-           |    "type" : {
-           |      "type" : "struct",
-           |      "fields" : [ {
-           |        "name" : "name",
-           |        "type" : "string",
-           |        "nullable" : true
-           |      }, {
-           |        "name" : "bff",
-           |        "type" : "void",
-           |        "nullable" : true
-           |      } ]
-           |    },
-           |    "nullable" : true
-           |  } ]
-           |}""".stripMargin).asInstanceOf[StructType]
-    val expectedDfZero = spark.createDataFrame(
-      spark.sparkContext.parallelize(Seq(Row(Row("person0", null)))), 
schemaZero)
-
-    testFromProtobufWithOptions(df, expectedDfZero, optionsZero, "EventPerson")
-
-    val optionsOne = new java.util.HashMap[String, String]()
-    optionsOne.put("recursive.fields.max.depth", "2")
-    val schemaOne = DataType.fromJson(
-      s"""{
-         |  "type" : "struct",
-         |  "fields" : [ {
-         |    "name" : "sample",
-         |    "type" : {
-         |      "type" : "struct",
-         |      "fields" : [ {
-         |        "name" : "name",
-         |        "type" : "string",
-         |        "nullable" : true
-         |      }, {
-         |        "name" : "bff",
-         |        "type" : {
-         |          "type" : "struct",
-         |          "fields" : [ {
-         |            "name" : "name",
-         |            "type" : "string",
-         |            "nullable" : true
-         |          }, {
-         |            "name" : "bff",
-         |            "type" : "void",
-         |            "nullable" : true
-         |          } ]
-         |        },
-         |        "nullable" : true
-         |      } ]
-         |    },
-         |    "nullable" : true
-         |  } ]
-         |}""".stripMargin).asInstanceOf[StructType]
+    val schemaOne = structFromDDL(
+      "sample STRUCT<name: STRING>" // 'bff' field is dropped to due to limit 
of 1.
+    )
     val expectedDfOne = spark.createDataFrame(
-      spark.sparkContext.parallelize(Seq(Row(Row("person0", Row("person1", 
null))))), schemaOne)
-    testFromProtobufWithOptions(df, expectedDfOne, optionsOne, "EventPerson")
+      spark.sparkContext.parallelize(Seq(Row(Row("person0", null)))), 
schemaOne)
+    testFromProtobufWithOptions(df, expectedDfOne, optionsZero, "EventPerson")
 
     val optionsTwo = new java.util.HashMap[String, String]()
-    optionsTwo.put("recursive.fields.max.depth", "3")
-    val schemaTwo = DataType.fromJson(
-      s"""{
-         |  "type" : "struct",
-         |  "fields" : [ {
-         |    "name" : "sample",
-         |    "type" : {
-         |      "type" : "struct",
-         |      "fields" : [ {
-         |        "name" : "name",
-         |        "type" : "string",
-         |        "nullable" : true
-         |      }, {
-         |        "name" : "bff",
-         |        "type" : {
-         |          "type" : "struct",
-         |          "fields" : [ {
-         |            "name" : "name",
-         |            "type" : "string",
-         |            "nullable" : true
-         |          }, {
-         |            "name" : "bff",
-         |            "type" : {
-         |              "type" : "struct",
-         |              "fields" : [ {
-         |                "name" : "name",
-         |                "type" : "string",
-         |                "nullable" : true
-         |              }, {
-         |                "name" : "bff",
-         |                "type" : "void",
-         |                "nullable" : true
-         |              } ]
-         |            },
-         |            "nullable" : true
-         |          } ]
-         |        },
-         |        "nullable" : true
-         |      } ]
-         |    },
-         |    "nullable" : true
-         |  } ]
-         |}""".stripMargin).asInstanceOf[StructType]
-    val expectedDfTwo = spark.createDataFrame(spark.sparkContext.parallelize(
-      Seq(Row(Row("person0", Row("person1", Row("person2", null)))))), 
schemaTwo)
+    optionsTwo.put("recursive.fields.max.depth", "2")
+    val schemaTwo = structFromDDL(
+      """
+        | sample STRUCT<
+        |     name: STRING,
+        |     bff: STRUCT<name: STRING> -- Recursion is terminated here.
+        | >
+        |""".stripMargin)
+    val expectedDfTwo = spark.createDataFrame(
+      spark.sparkContext.parallelize(Seq(Row(Row("person0", Row("person1", 
null))))), schemaTwo)
     testFromProtobufWithOptions(df, expectedDfTwo, optionsTwo, "EventPerson")
 
+    val optionsThree = new java.util.HashMap[String, String]()
+    optionsThree.put("recursive.fields.max.depth", "3")
+    val schemaThree = structFromDDL(
+      """
+        | sample STRUCT<
+        |     name: STRING,
+        |     bff: STRUCT<
+        |         name: STRING,
+        |         bff: STRUCT<name: STRING>
+        |     >
+        | >
+        |""".stripMargin)
+    val expectedDfThree = spark.createDataFrame(spark.sparkContext.parallelize(
+      Seq(Row(Row("person0", Row("person1", Row("person2", null)))))), 
schemaThree)
+    testFromProtobufWithOptions(df, expectedDfThree, optionsThree, 
"EventPerson")
+
     // Test recursive level 1 with EventPersonWrapper. In this case the top 
level struct
     // 'EventPersonWrapper' itself does not recurse unlike 'EventPerson'.
     // "bff" appears twice: Once allowed recursion and second time as 
terminated "null" type.
-    val wrapperSchemaOne = DataType.fromJson(
+    val wrapperSchemaOne = structFromDDL(
       """
-        |{
-        |  "type" : "struct",
-        |  "fields" : [ {
-        |    "name" : "sample",
-        |    "type" : {
-        |      "type" : "struct",
-        |      "fields" : [ {
-        |        "name" : "person",
-        |        "type" : {
-        |          "type" : "struct",
-        |          "fields" : [ {
-        |            "name" : "name",
-        |            "type" : "string",
-        |            "nullable" : true
-        |          }, {
-        |            "name" : "bff",
-        |            "type" : {
-        |              "type" : "struct",
-        |              "fields" : [ {
-        |                "name" : "name",
-        |                "type" : "string",
-        |                "nullable" : true
-        |              }, {
-        |                "name" : "bff",
-        |                "type" : "void",
-        |                "nullable" : true
-        |              } ]
-        |            },
-        |            "nullable" : true
-        |          } ]
-        |        },
-        |        "nullable" : true
-        |      } ]
-        |    },
-        |    "nullable" : true
-        |  } ]
-        |}
+        | sample STRUCT<
+        |     person: STRUCT< -- 1st level
+        |         name: STRING,
+        |         bff: STRUCT<name: STRING> -- 2nd level. Inner 3rd level 
Person is dropped.
+        |     >
+        | >
         |""".stripMargin).asInstanceOf[StructType]
-    val expectedWrapperDfOne = spark.createDataFrame(
+    val expectedWrapperDfTwo = spark.createDataFrame(
       spark.sparkContext.parallelize(Seq(Row(Row(Row("person0", Row("person1", 
null)))))),
       wrapperSchemaOne)
     testFromProtobufWithOptions(
       
Seq(EventPersonWrapper.newBuilder().setPerson(eventPerson0).build().toByteArray).toDF(),
-      expectedWrapperDfOne,
-      optionsOne,
+      expectedWrapperDfTwo,
+      optionsTwo,
       "EventPersonWrapper"
     )
   }
@@ -1287,6 +1020,92 @@ class ProtobufFunctionsSuite extends QueryTest with 
SharedSparkSession with Prot
     assert(ex.getCause.getMessage.matches(".*No such file.*"), 
ex.getCause.getMessage())
   }
 
+  test("Recursive fields in arrays and maps") {
+    // Verifies schema for recursive proto in an array field & map field.
+    val options = Map("recursive.fields.max.depth" -> "3")
+
+    checkWithFileAndClassName("PersonWithRecursiveArray") {
+      case (name, descFilePathOpt) =>
+        val expectedSchema = StructType(
+          // DDL: "proto STRUCT<name: string, friends: array<
+          //    struct<name: string, friends: array<struct<name: string>>>>>"
+          // Cannot use DataType.fromDDL(); it does not support "containsNull: false" for arrays.
+          StructField("proto",
+            StructType( // 1st level
+              StructField("name", StringType) :: StructField("friends", // 2nd 
level
+                ArrayType(
+                  StructType(StructField("name", StringType) :: 
StructField("friends", // 3rd level
+                    ArrayType(
+                      StructType(StructField("name", StringType) :: Nil), // 
4th, array dropped
+                      containsNull = false)
+                  ):: Nil),
+                  containsNull = false)
+              ) :: Nil
+            )
+          ) :: Nil
+        )
+
+        val df = emptyBinaryDF.select(
+          from_protobuf_wrapper($"binary", name, descFilePathOpt, 
options).as("proto")
+        )
+        assert(df.schema == expectedSchema)
+    }
+
+    checkWithFileAndClassName("PersonWithRecursiveMap") {
+      case (name, descFilePathOpt) =>
+        val expectedSchema = StructType(
+          // DDL: "proto STRUCT<name: string, groups: map<
+          //    struct<name: string, group: map<struct<name: string>>>>>"
+          StructField("proto",
+            StructType( // 1st level
+              StructField("name", StringType) :: StructField("groups", // 2nd 
level
+                MapType(
+                  StringType,
+                  StructType(StructField("name", StringType) :: 
StructField("groups", // 3rd level
+                    MapType(
+                      StringType,
+                      StructType(StructField("name", StringType) :: Nil), // 
4th, array dropped
+                      valueContainsNull = false)
+                  ):: Nil),
+                  valueContainsNull = false)
+              ) :: Nil
+            )
+          ) :: Nil
+        )
+
+        val df = emptyBinaryDF.select(
+          from_protobuf_wrapper($"binary", name, descFilePathOpt, 
options).as("proto")
+        )
+        assert(df.schema == expectedSchema)
+    }
+  }
+
+  test("Corner case: empty recursive proto fields should be dropped") {
+    // This verifies that an empty proto like 'message A { A a = 1; }' is completely dropped
+    // irrespective of the max depth setting.
+
+    val options = Map("recursive.fields.max.depth" -> "4")
+
+    // EmptyRecursiveProto at the top level. It will be an empty struct.
+    checkWithFileAndClassName("EmptyRecursiveProto") {
+      case (name, descFilePathOpt) =>
+          val df = emptyBinaryDF.select(
+            from_protobuf_wrapper($"binary", name, descFilePathOpt, 
options).as("empty_proto")
+          )
+        assert(df.schema == structFromDDL("empty_proto struct<>"))
+    }
+
+    // EmptyRecursiveProto at inner level.
+    checkWithFileAndClassName("EmptyRecursiveProtoWrapper") {
+      case (name, descFilePathOpt) =>
+        val df = emptyBinaryDF.select(
+          from_protobuf_wrapper($"binary", name, descFilePathOpt, 
options).as("wrapper")
+        )
+        // 'empty_recursive' field is dropped from the schema. Only "name" is 
present.
+        assert(df.schema == structFromDDL("wrapper struct<name: string>"))
+    }
+  }
+
   def testFromProtobufWithOptions(
     df: DataFrame,
     expectedDf: DataFrame,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

