This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5241d98  [SPARK-36546][SQL] Add array support to union by name
5241d98 is described below

commit 5241d9880036c43ff29f7a995e190026349bf838
Author: Adam Binford <adam...@gmail.com>
AuthorDate: Wed Oct 13 19:13:01 2021 -0700

    [SPARK-36546][SQL] Add array support to union by name
    
    ### What changes were proposed in this pull request?
    
    This PR adds array of struct support to unionByName. It further tries to 
simplify some of the logic for re-projecting different types by creating a 
`mergeFields` method that can find the right method to merge various 
combinations of fields, currently structs or arrays, but could add maps in the 
future.
    
    ### Why are the changes needed?
    
    Currently unionByName doesn't support arrays of structs or maps of structs. 
This adds support for arrays of structs and should make it easy to add support 
for maps of structs in the future.
    
    ### Does this PR introduce _any_ user-facing change?
    
    New capability to unionByName with arrays of structs.
    
    ### How was this patch tested?
    
    New unit tests
    
    Closes #34246 from Kimahriman/union-by-name-array.
    
    Authored-by: Adam Binford <adam...@gmail.com>
    Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com>
---
 .../spark/sql/catalyst/analysis/ResolveUnion.scala |  71 ++++--
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  10 +-
 .../spark/sql/DataFrameSetOperationsSuite.scala    | 272 +++++++++++++++++++++
 3 files changed, 326 insertions(+), 27 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
index 0d805c5..fff38bb 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
@@ -33,12 +33,27 @@ import org.apache.spark.sql.util.SchemaUtils
  */
 object ResolveUnion extends Rule[LogicalPlan] {
   /**
+   * Transform the array of structs to the target struct type.
+   */
+  private def transformArray(arrayCol: Expression, targetType: ArrayType,
+      allowMissing: Boolean) = {
+    assert(arrayCol.dataType.isInstanceOf[ArrayType], "Only support 
ArrayType.")
+
+    val arrayType = arrayCol.dataType.asInstanceOf[ArrayType]
+
+    val x = 
NamedLambdaVariable(UnresolvedNamedLambdaVariable.freshVarName("x"),
+      arrayType.elementType,
+      arrayType.containsNull)
+    val function = mergeFields(x, targetType.elementType, allowMissing)
+    ArrayTransform(arrayCol, LambdaFunction(function, Seq(x)))
+  }
+
+  /**
    * Adds missing fields recursively into given `col` expression, based on the 
expected struct
    * fields from merging the two schemas. This is called by 
`compareAndAddFields` when we find two
    * struct columns with same name but different nested fields. This method 
will recursively
    * return a new struct with all of the expected fields, adding null values 
when `col` doesn't
-   * already contain them. Currently we don't support merging structs nested 
inside of arrays
-   * or maps.
+   * already contain them. Currently we don't support merging structs nested 
inside of maps.
    */
   private def addFields(col: Expression,
      targetType: StructType, allowMissing: Boolean): Expression = {
@@ -53,12 +68,8 @@ object ResolveUnion extends Rule[LogicalPlan] {
       val currentField = colType.fields.find(f => resolver(f.name, 
expectedField.name))
 
       val newExpression = (currentField, expectedField.dataType) match {
-        case (Some(cf), expectedType: StructType) if 
cf.dataType.isInstanceOf[StructType]
-            && !DataType.equalsStructurallyByName(cf.dataType, expectedType, 
resolver) =>
-          val extractedValue = ExtractValue(col, Literal(cf.name), resolver)
-          addFields(extractedValue, expectedType, allowMissing)
-        case (Some(cf), _) =>
-          ExtractValue(col, Literal(cf.name), resolver)
+        case (Some(cf), expectedType) =>
+          mergeFields(ExtractValue(col, Literal(cf.name), resolver), 
expectedType, allowMissing)
         case (None, expectedType) =>
           if (allowMissing) {
             // for allowMissingCol allow the null values
@@ -87,6 +98,26 @@ object ResolveUnion extends Rule[LogicalPlan] {
   }
 
   /**
+   * Handles the merging of complex types. Currently supports structs and 
arrays recursively.
+   */
+  private def mergeFields(col: Expression, targetType: DataType,
+      allowMissing: Boolean): Expression = {
+    if (!DataType.equalsStructurallyByName(col.dataType, targetType, 
conf.resolver)) {
+      (col.dataType, targetType) match {
+        case (_: StructType, targetStruct: StructType) =>
+          addFields(col, targetStruct, allowMissing)
+        case (_: ArrayType, targetArray: ArrayType) =>
+          transformArray(col, targetArray, allowMissing)
+        case _ =>
+          // Unsupported combination, let the resulting union analyze
+          col
+      }
+    } else {
+      col
+    }
+  }
+
+  /**
    * This method will compare right to left plan's outputs. If there is one 
struct attribute
    * at right side has same name with left side struct attribute, but two 
structs are not the
    * same data type, i.e., some missing (nested) fields at right struct 
attribute, then this
@@ -107,22 +138,14 @@ object ResolveUnion extends Rule[LogicalPlan] {
       if (found.isDefined) {
         val foundAttr = found.get
         val foundDt = foundAttr.dataType
-        (foundDt, lattr.dataType) match {
-          case (source: StructType, target: StructType)
-              if !DataType.equalsStructurallyByName(source, target, resolver) 
=>
-            // We have two structs with different types, so make sure the two 
structs have their
-            // fields in the same order by using `target`'s fields and then 
including any remaining
-            // in `foundAttr` in case of allowMissingCol is true.
-            aliased += foundAttr
-            Alias(addFields(foundAttr, target, allowMissingCol), 
foundAttr.name)()
-          case _ =>
-            // We don't need/try to add missing fields if:
-            // 1. The attributes of left and right side are the same struct 
type
-            // 2. The attributes are not struct types. They might be primitive 
types, or array, map
-            //    types. We don't support adding missing fields of nested 
structs in array or map
-            //    types now.
-            // 3. `allowMissingCol` is disabled.
-            foundAttr
+        if (!DataType.equalsStructurallyByName(foundDt, lattr.dataType, 
resolver)) {
+          // The two types are complex and have different nested structs at 
some level.
+          // Map types are currently not supported and will return the 
existing attribute.
+          aliased += foundAttr
+          Alias(mergeFields(foundAttr, lattr.dataType, allowMissingCol), 
foundAttr.name)()
+        } else {
+          // Either both sides are primitive types or equivalent complex types
+          foundAttr
         }
       } else {
         if (allowMissingCol) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 22e914e..c8cdc20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2115,6 +2115,9 @@ class Dataset[T] private[sql](
    *   // +----+----+----+
    * }}}
    *
+   * Note that this supports nested columns in struct and array types. Nested 
columns in map types
+   * are not currently supported.
+   *
    * @group typedrel
    * @since 2.3.0
    */
@@ -2155,9 +2158,10 @@ class Dataset[T] private[sql](
    *   // +----+----+----+----+
    * }}}
    *
-   * Note that `allowMissingColumns` supports nested column in struct types. 
Missing nested columns
-   * of struct columns with the same name will also be filled with null values 
and added to the end
-   * of struct. This currently does not support nested columns in array and 
map types.
+   * Note that this supports nested columns in struct and array types. With 
`allowMissingColumns`,
+   * missing nested columns of struct columns with the same name will also be 
filled with null
+   * values and added to the end of struct. Nested columns in map types are 
not currently
+   * supported.
    *
    * @group typedrel
    * @since 3.1.0
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
index 4e00de0..650d878 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
@@ -1083,6 +1083,278 @@ class DataFrameSetOperationsSuite extends QueryTest 
with SharedSparkSession {
     assert(err.message
       .contains("Union can only be performed on tables with the compatible 
column types"))
   }
+
+  test("SPARK-36546: Add unionByName support to arrays of structs") {
+    val arrayType1 = ArrayType(
+      StructType(Seq(
+        StructField("ba", StringType),
+        StructField("bb", StringType)
+      ))
+    )
+    val arrayValues1 = Seq(Row("ba", "bb"))
+
+    val arrayType2 = ArrayType(
+      StructType(Seq(
+        StructField("bb", StringType),
+        StructField("ba", StringType)
+      ))
+    )
+    val arrayValues2 = Seq(Row("bb", "ba"))
+
+    val df1 = spark.createDataFrame(
+      sparkContext.parallelize(Row(arrayValues1) :: Nil),
+      StructType(Seq(StructField("arr", arrayType1))))
+
+    val df2 = spark.createDataFrame(
+      sparkContext.parallelize(Row(arrayValues2) :: Nil),
+      StructType(Seq(StructField("arr", arrayType2))))
+
+    var unionDf = df1.unionByName(df2)
+    assert(unionDf.schema.toDDL == "`arr` ARRAY<STRUCT<`ba`: STRING, `bb`: 
STRING>>")
+    checkAnswer(unionDf,
+      Row(Seq(Row("ba", "bb"))) ::
+      Row(Seq(Row("ba", "bb"))) :: Nil)
+
+    unionDf = df2.unionByName(df1)
+    assert(unionDf.schema.toDDL == "`arr` ARRAY<STRUCT<`bb`: STRING, `ba`: 
STRING>>")
+    checkAnswer(unionDf,
+      Row(Seq(Row("bb", "ba"))) ::
+      Row(Seq(Row("bb", "ba"))) :: Nil)
+
+    val arrayType3 = ArrayType(
+      StructType(Seq(
+        StructField("ba", StringType)
+      ))
+    )
+    val arrayValues3 = Seq(Row("ba"))
+
+    val arrayType4 = ArrayType(
+      StructType(Seq(
+        StructField("bb", StringType)
+      ))
+    )
+    val arrayValues4 = Seq(Row("bb"))
+
+    val df3 = spark.createDataFrame(
+      sparkContext.parallelize(Row(arrayValues3) :: Nil),
+      StructType(Seq(StructField("arr", arrayType3))))
+
+    val df4 = spark.createDataFrame(
+      sparkContext.parallelize(Row(arrayValues4) :: Nil),
+      StructType(Seq(StructField("arr", arrayType4))))
+
+    assertThrows[AnalysisException] {
+      df3.unionByName(df4)
+    }
+
+    unionDf = df3.unionByName(df4, true)
+    assert(unionDf.schema.toDDL == "`arr` ARRAY<STRUCT<`ba`: STRING, `bb`: 
STRING>>")
+    checkAnswer(unionDf,
+      Row(Seq(Row("ba", null))) ::
+      Row(Seq(Row(null, "bb"))) :: Nil)
+
+    assertThrows[AnalysisException] {
+      df4.unionByName(df3)
+    }
+
+    unionDf = df4.unionByName(df3, true)
+    assert(unionDf.schema.toDDL == "`arr` ARRAY<STRUCT<`bb`: STRING, `ba`: 
STRING>>")
+    checkAnswer(unionDf,
+      Row(Seq(Row("bb", null))) ::
+      Row(Seq(Row(null, "ba"))) :: Nil)
+  }
+
+  test("SPARK-36546: Add unionByName support to nested arrays of structs") {
+    val nestedStructType1 = StructType(Seq(
+      StructField("b", ArrayType(
+        StructType(Seq(
+          StructField("ba", StringType),
+          StructField("bb", StringType)
+        ))
+      ))
+    ))
+    val nestedStructValues1 = Row(Seq(Row("ba", "bb")))
+
+    val nestedStructType2 = StructType(Seq(
+      StructField("b", ArrayType(
+        StructType(Seq(
+          StructField("bb", StringType),
+          StructField("ba", StringType)
+        ))
+      ))
+    ))
+    val nestedStructValues2 = Row(Seq(Row("bb", "ba")))
+
+    val df1 = spark.createDataFrame(
+      sparkContext.parallelize(Row(nestedStructValues1) :: Nil),
+      StructType(Seq(StructField("topLevelCol", nestedStructType1))))
+
+    val df2 = spark.createDataFrame(
+      sparkContext.parallelize(Row(nestedStructValues2) :: Nil),
+      StructType(Seq(StructField("topLevelCol", nestedStructType2))))
+
+    var unionDf = df1.unionByName(df2)
+    assert(unionDf.schema.toDDL == "`topLevelCol` " +
+      "STRUCT<`b`: ARRAY<STRUCT<`ba`: STRING, `bb`: STRING>>>")
+    checkAnswer(unionDf,
+      Row(Row(Seq(Row("ba", "bb")))) ::
+      Row(Row(Seq(Row("ba", "bb")))) :: Nil)
+
+    unionDf = df2.unionByName(df1)
+    assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<" +
+      "`b`: ARRAY<STRUCT<`bb`: STRING, `ba`: STRING>>>")
+    checkAnswer(unionDf,
+      Row(Row(Seq(Row("bb", "ba")))) ::
+      Row(Row(Seq(Row("bb", "ba")))) :: Nil)
+
+    val nestedStructType3 = StructType(Seq(
+      StructField("b", ArrayType(
+        StructType(Seq(
+          StructField("ba", StringType)
+        ))
+      ))
+    ))
+    val nestedStructValues3 = Row(Seq(Row("ba")))
+
+    val nestedStructType4 = StructType(Seq(
+      StructField("b", ArrayType(
+        StructType(Seq(
+          StructField("bb", StringType)
+        ))
+      ))
+    ))
+    val nestedStructValues4 = Row(Seq(Row("bb")))
+
+    val df3 = spark.createDataFrame(
+      sparkContext.parallelize(Row(nestedStructValues3) :: Nil),
+      StructType(Seq(StructField("topLevelCol", nestedStructType3))))
+
+    val df4 = spark.createDataFrame(
+      sparkContext.parallelize(Row(nestedStructValues4) :: Nil),
+      StructType(Seq(StructField("topLevelCol", nestedStructType4))))
+
+    assertThrows[AnalysisException] {
+      df3.unionByName(df4)
+    }
+
+    unionDf = df3.unionByName(df4, true)
+    assert(unionDf.schema.toDDL == "`topLevelCol` " +
+      "STRUCT<`b`: ARRAY<STRUCT<`ba`: STRING, `bb`: STRING>>>")
+    checkAnswer(unionDf,
+      Row(Row(Seq(Row("ba", null)))) ::
+      Row(Row(Seq(Row(null, "bb")))) :: Nil)
+
+    assertThrows[AnalysisException] {
+      df4.unionByName(df3)
+    }
+
+    unionDf = df4.unionByName(df3, true)
+    assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<" +
+      "`b`: ARRAY<STRUCT<`bb`: STRING, `ba`: STRING>>>")
+    checkAnswer(unionDf,
+      Row(Row(Seq(Row("bb", null)))) ::
+      Row(Row(Seq(Row(null, "ba")))) :: Nil)
+  }
+
+  test("SPARK-36546: Add unionByName support to multiple levels of nested 
arrays of structs") {
+    val nestedStructType1 = StructType(Seq(
+      StructField("b", ArrayType(
+        ArrayType(
+          StructType(Seq(
+            StructField("ba", StringType),
+            StructField("bb", StringType)
+          ))
+        )
+      ))
+    ))
+    val nestedStructValues1 = Row(Seq(Seq(Row("ba", "bb"))))
+
+    val nestedStructType2 = StructType(Seq(
+      StructField("b", ArrayType(
+        ArrayType(
+          StructType(Seq(
+            StructField("bb", StringType),
+            StructField("ba", StringType)
+          ))
+        )
+      ))
+    ))
+    val nestedStructValues2 = Row(Seq(Seq(Row("bb", "ba"))))
+
+    val df1 = spark.createDataFrame(
+      sparkContext.parallelize(Row(nestedStructValues1) :: Nil),
+      StructType(Seq(StructField("topLevelCol", nestedStructType1))))
+
+    val df2 = spark.createDataFrame(
+      sparkContext.parallelize(Row(nestedStructValues2) :: Nil),
+      StructType(Seq(StructField("topLevelCol", nestedStructType2))))
+
+    var unionDf = df1.unionByName(df2)
+    assert(unionDf.schema.toDDL == "`topLevelCol` " +
+      "STRUCT<`b`: ARRAY<ARRAY<STRUCT<`ba`: STRING, `bb`: STRING>>>>")
+    checkAnswer(unionDf,
+      Row(Row(Seq(Seq(Row("ba", "bb"))))) ::
+      Row(Row(Seq(Seq(Row("ba", "bb"))))) :: Nil)
+
+    unionDf = df2.unionByName(df1)
+    assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<" +
+      "`b`: ARRAY<ARRAY<STRUCT<`bb`: STRING, `ba`: STRING>>>>")
+    checkAnswer(unionDf,
+      Row(Row(Seq(Seq(Row("bb", "ba"))))) ::
+      Row(Row(Seq(Seq(Row("bb", "ba"))))) :: Nil)
+
+    val nestedStructType3 = StructType(Seq(
+      StructField("b", ArrayType(
+        ArrayType(
+          StructType(Seq(
+            StructField("ba", StringType)
+          ))
+        )
+      ))
+    ))
+    val nestedStructValues3 = Row(Seq(Seq(Row("ba"))))
+
+    val nestedStructType4 = StructType(Seq(
+      StructField("b", ArrayType(
+        ArrayType(
+          StructType(Seq(
+            StructField("bb", StringType)
+          ))
+        )
+      ))
+    ))
+    val nestedStructValues4 = Row(Seq(Seq(Row("bb"))))
+
+    val df3 = spark.createDataFrame(
+      sparkContext.parallelize(Row(nestedStructValues3) :: Nil),
+      StructType(Seq(StructField("topLevelCol", nestedStructType3))))
+
+    val df4 = spark.createDataFrame(
+      sparkContext.parallelize(Row(nestedStructValues4) :: Nil),
+      StructType(Seq(StructField("topLevelCol", nestedStructType4))))
+
+    assertThrows[AnalysisException] {
+      df3.unionByName(df4)
+    }
+
+    unionDf = df3.unionByName(df4, true)
+    assert(unionDf.schema.toDDL == "`topLevelCol` " +
+      "STRUCT<`b`: ARRAY<ARRAY<STRUCT<`ba`: STRING, `bb`: STRING>>>>")
+    checkAnswer(unionDf,
+      Row(Row(Seq(Seq(Row("ba", null))))) ::
+      Row(Row(Seq(Seq(Row(null, "bb"))))) :: Nil)
+
+    assertThrows[AnalysisException] {
+      df4.unionByName(df3)
+    }
+
+    unionDf = df4.unionByName(df3, true)
+    assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<" +
+      "`b`: ARRAY<ARRAY<STRUCT<`bb`: STRING, `ba`: STRING>>>>")
+    checkAnswer(unionDf,
+      Row(Row(Seq(Seq(Row("bb", null))))) ::
+      Row(Row(Seq(Seq(Row(null, "ba"))))) :: Nil)
+  }
 }
 
 case class UnionClass1a(a: Int, b: Long, nested: UnionClass2)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to