This is an automated email from the ASF dual-hosted git repository. viirya pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5241d98 [SPARK-36546][SQL] Add array support to union by name 5241d98 is described below commit 5241d9880036c43ff29f7a995e190026349bf838 Author: Adam Binford <adam...@gmail.com> AuthorDate: Wed Oct 13 19:13:01 2021 -0700 [SPARK-36546][SQL] Add array support to union by name ### What changes were proposed in this pull request? This PR adds support for arrays of structs to unionByName. It also simplifies some of the logic for re-projecting different types by introducing a `mergeFields` method that selects the right way to merge various combinations of fields (currently structs or arrays); it could be extended to maps in the future. ### Why are the changes needed? Currently unionByName doesn't support arrays of structs or maps of structs. This adds support for arrays of structs and should make it easy to add support for maps of structs in the future. ### Does this PR introduce _any_ user-facing change? New capability to unionByName with arrays of structs. ### How was this patch tested? New unit tests. Closes #34246 from Kimahriman/union-by-name-array. 
Authored-by: Adam Binford <adam...@gmail.com> Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com> --- .../spark/sql/catalyst/analysis/ResolveUnion.scala | 71 ++++-- .../main/scala/org/apache/spark/sql/Dataset.scala | 10 +- .../spark/sql/DataFrameSetOperationsSuite.scala | 272 +++++++++++++++++++++ 3 files changed, 326 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 0d805c5..fff38bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -33,12 +33,27 @@ import org.apache.spark.sql.util.SchemaUtils */ object ResolveUnion extends Rule[LogicalPlan] { /** + * Transform the array of structs to the target struct type. + */ + private def transformArray(arrayCol: Expression, targetType: ArrayType, + allowMissing: Boolean) = { + assert(arrayCol.dataType.isInstanceOf[ArrayType], "Only support ArrayType.") + + val arrayType = arrayCol.dataType.asInstanceOf[ArrayType] + + val x = NamedLambdaVariable(UnresolvedNamedLambdaVariable.freshVarName("x"), + arrayType.elementType, + arrayType.containsNull) + val function = mergeFields(x, targetType.elementType, allowMissing) + ArrayTransform(arrayCol, LambdaFunction(function, Seq(x))) + } + + /** * Adds missing fields recursively into given `col` expression, based on the expected struct * fields from merging the two schemas. This is called by `compareAndAddFields` when we find two * struct columns with same name but different nested fields. This method will recursively * return a new struct with all of the expected fields, adding null values when `col` doesn't - * already contain them. Currently we don't support merging structs nested inside of arrays - * or maps. + * already contain them. 
Currently we don't support merging structs nested inside of maps. */ private def addFields(col: Expression, targetType: StructType, allowMissing: Boolean): Expression = { @@ -53,12 +68,8 @@ object ResolveUnion extends Rule[LogicalPlan] { val currentField = colType.fields.find(f => resolver(f.name, expectedField.name)) val newExpression = (currentField, expectedField.dataType) match { - case (Some(cf), expectedType: StructType) if cf.dataType.isInstanceOf[StructType] - && !DataType.equalsStructurallyByName(cf.dataType, expectedType, resolver) => - val extractedValue = ExtractValue(col, Literal(cf.name), resolver) - addFields(extractedValue, expectedType, allowMissing) - case (Some(cf), _) => - ExtractValue(col, Literal(cf.name), resolver) + case (Some(cf), expectedType) => + mergeFields(ExtractValue(col, Literal(cf.name), resolver), expectedType, allowMissing) case (None, expectedType) => if (allowMissing) { // for allowMissingCol allow the null values @@ -87,6 +98,26 @@ object ResolveUnion extends Rule[LogicalPlan] { } /** + * Handles the merging of complex types. Currently supports structs and arrays recursively. + */ + private def mergeFields(col: Expression, targetType: DataType, + allowMissing: Boolean): Expression = { + if (!DataType.equalsStructurallyByName(col.dataType, targetType, conf.resolver)) { + (col.dataType, targetType) match { + case (_: StructType, targetStruct: StructType) => + addFields(col, targetStruct, allowMissing) + case (_: ArrayType, targetArray: ArrayType) => + transformArray(col, targetArray, allowMissing) + case _ => + // Unsupported combination, let the resulting union analyze + col + } + } else { + col + } + } + + /** * This method will compare right to left plan's outputs. 
If there is one struct attribute * at right side has same name with left side struct attribute, but two structs are not the * same data type, i.e., some missing (nested) fields at right struct attribute, then this @@ -107,22 +138,14 @@ object ResolveUnion extends Rule[LogicalPlan] { if (found.isDefined) { val foundAttr = found.get val foundDt = foundAttr.dataType - (foundDt, lattr.dataType) match { - case (source: StructType, target: StructType) - if !DataType.equalsStructurallyByName(source, target, resolver) => - // We have two structs with different types, so make sure the two structs have their - // fields in the same order by using `target`'s fields and then including any remaining - // in `foundAttr` in case of allowMissingCol is true. - aliased += foundAttr - Alias(addFields(foundAttr, target, allowMissingCol), foundAttr.name)() - case _ => - // We don't need/try to add missing fields if: - // 1. The attributes of left and right side are the same struct type - // 2. The attributes are not struct types. They might be primitive types, or array, map - // types. We don't support adding missing fields of nested structs in array or map - // types now. - // 3. `allowMissingCol` is disabled. - foundAttr + if (!DataType.equalsStructurallyByName(foundDt, lattr.dataType, resolver)) { + // The two types are complex and have different nested structs at some level. + // Map types are currently not supported and will return the existing attribute. 
+ aliased += foundAttr + Alias(mergeFields(foundAttr, lattr.dataType, allowMissingCol), foundAttr.name)() + } else { + // Either both sides are primitive types or equivalent complex types + foundAttr } } else { if (allowMissingCol) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 22e914e..c8cdc20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2115,6 +2115,9 @@ class Dataset[T] private[sql]( * // +----+----+----+ * }}} * + * Note that this supports nested columns in struct and array types. Nested columns in map types + * are not currently supported. + * * @group typedrel * @since 2.3.0 */ @@ -2155,9 +2158,10 @@ class Dataset[T] private[sql]( * // +----+----+----+----+ * }}} * - * Note that `allowMissingColumns` supports nested column in struct types. Missing nested columns - * of struct columns with the same name will also be filled with null values and added to the end - * of struct. This currently does not support nested columns in array and map types. + * Note that this supports nested columns in struct and array types. With `allowMissingColumns`, + * missing nested columns of struct columns with the same name will also be filled with null + * values and added to the end of struct. Nested columns in map types are not currently + * supported. 
* * @group typedrel * @since 3.1.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index 4e00de0..650d878 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -1083,6 +1083,278 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { assert(err.message .contains("Union can only be performed on tables with the compatible column types")) } + + test("SPARK-36546: Add unionByName support to arrays of structs") { + val arrayType1 = ArrayType( + StructType(Seq( + StructField("ba", StringType), + StructField("bb", StringType) + )) + ) + val arrayValues1 = Seq(Row("ba", "bb")) + + val arrayType2 = ArrayType( + StructType(Seq( + StructField("bb", StringType), + StructField("ba", StringType) + )) + ) + val arrayValues2 = Seq(Row("bb", "ba")) + + val df1 = spark.createDataFrame( + sparkContext.parallelize(Row(arrayValues1) :: Nil), + StructType(Seq(StructField("arr", arrayType1)))) + + val df2 = spark.createDataFrame( + sparkContext.parallelize(Row(arrayValues2) :: Nil), + StructType(Seq(StructField("arr", arrayType2)))) + + var unionDf = df1.unionByName(df2) + assert(unionDf.schema.toDDL == "`arr` ARRAY<STRUCT<`ba`: STRING, `bb`: STRING>>") + checkAnswer(unionDf, + Row(Seq(Row("ba", "bb"))) :: + Row(Seq(Row("ba", "bb"))) :: Nil) + + unionDf = df2.unionByName(df1) + assert(unionDf.schema.toDDL == "`arr` ARRAY<STRUCT<`bb`: STRING, `ba`: STRING>>") + checkAnswer(unionDf, + Row(Seq(Row("bb", "ba"))) :: + Row(Seq(Row("bb", "ba"))) :: Nil) + + val arrayType3 = ArrayType( + StructType(Seq( + StructField("ba", StringType) + )) + ) + val arrayValues3 = Seq(Row("ba")) + + val arrayType4 = ArrayType( + StructType(Seq( + StructField("bb", StringType) + )) + ) + val arrayValues4 = Seq(Row("bb")) + + val df3 = 
spark.createDataFrame( + sparkContext.parallelize(Row(arrayValues3) :: Nil), + StructType(Seq(StructField("arr", arrayType3)))) + + val df4 = spark.createDataFrame( + sparkContext.parallelize(Row(arrayValues4) :: Nil), + StructType(Seq(StructField("arr", arrayType4)))) + + assertThrows[AnalysisException] { + df3.unionByName(df4) + } + + unionDf = df3.unionByName(df4, true) + assert(unionDf.schema.toDDL == "`arr` ARRAY<STRUCT<`ba`: STRING, `bb`: STRING>>") + checkAnswer(unionDf, + Row(Seq(Row("ba", null))) :: + Row(Seq(Row(null, "bb"))) :: Nil) + + assertThrows[AnalysisException] { + df4.unionByName(df3) + } + + unionDf = df4.unionByName(df3, true) + assert(unionDf.schema.toDDL == "`arr` ARRAY<STRUCT<`bb`: STRING, `ba`: STRING>>") + checkAnswer(unionDf, + Row(Seq(Row("bb", null))) :: + Row(Seq(Row(null, "ba"))) :: Nil) + } + + test("SPARK-36546: Add unionByName support to nested arrays of structs") { + val nestedStructType1 = StructType(Seq( + StructField("b", ArrayType( + StructType(Seq( + StructField("ba", StringType), + StructField("bb", StringType) + )) + )) + )) + val nestedStructValues1 = Row(Seq(Row("ba", "bb"))) + + val nestedStructType2 = StructType(Seq( + StructField("b", ArrayType( + StructType(Seq( + StructField("bb", StringType), + StructField("ba", StringType) + )) + )) + )) + val nestedStructValues2 = Row(Seq(Row("bb", "ba"))) + + val df1 = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues1) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType1)))) + + val df2 = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues2) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType2)))) + + var unionDf = df1.unionByName(df2) + assert(unionDf.schema.toDDL == "`topLevelCol` " + + "STRUCT<`b`: ARRAY<STRUCT<`ba`: STRING, `bb`: STRING>>>") + checkAnswer(unionDf, + Row(Row(Seq(Row("ba", "bb")))) :: + Row(Row(Seq(Row("ba", "bb")))) :: Nil) + + unionDf = df2.unionByName(df1) + 
assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<" + + "`b`: ARRAY<STRUCT<`bb`: STRING, `ba`: STRING>>>") + checkAnswer(unionDf, + Row(Row(Seq(Row("bb", "ba")))) :: + Row(Row(Seq(Row("bb", "ba")))) :: Nil) + + val nestedStructType3 = StructType(Seq( + StructField("b", ArrayType( + StructType(Seq( + StructField("ba", StringType) + )) + )) + )) + val nestedStructValues3 = Row(Seq(Row("ba"))) + + val nestedStructType4 = StructType(Seq( + StructField("b", ArrayType( + StructType(Seq( + StructField("bb", StringType) + )) + )) + )) + val nestedStructValues4 = Row(Seq(Row("bb"))) + + val df3 = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues3) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType3)))) + + val df4 = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues4) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType4)))) + + assertThrows[AnalysisException] { + df3.unionByName(df4) + } + + unionDf = df3.unionByName(df4, true) + assert(unionDf.schema.toDDL == "`topLevelCol` " + + "STRUCT<`b`: ARRAY<STRUCT<`ba`: STRING, `bb`: STRING>>>") + checkAnswer(unionDf, + Row(Row(Seq(Row("ba", null)))) :: + Row(Row(Seq(Row(null, "bb")))) :: Nil) + + assertThrows[AnalysisException] { + df4.unionByName(df3) + } + + unionDf = df4.unionByName(df3, true) + assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<" + + "`b`: ARRAY<STRUCT<`bb`: STRING, `ba`: STRING>>>") + checkAnswer(unionDf, + Row(Row(Seq(Row("bb", null)))) :: + Row(Row(Seq(Row(null, "ba")))) :: Nil) + } + + test("SPARK-36546: Add unionByName support to multiple levels of nested arrays of structs") { + val nestedStructType1 = StructType(Seq( + StructField("b", ArrayType( + ArrayType( + StructType(Seq( + StructField("ba", StringType), + StructField("bb", StringType) + )) + ) + )) + )) + val nestedStructValues1 = Row(Seq(Seq(Row("ba", "bb")))) + + val nestedStructType2 = StructType(Seq( + StructField("b", ArrayType( + ArrayType( + 
StructType(Seq( + StructField("bb", StringType), + StructField("ba", StringType) + )) + ) + )) + )) + val nestedStructValues2 = Row(Seq(Seq(Row("bb", "ba")))) + + val df1 = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues1) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType1)))) + + val df2 = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues2) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType2)))) + + var unionDf = df1.unionByName(df2) + assert(unionDf.schema.toDDL == "`topLevelCol` " + + "STRUCT<`b`: ARRAY<ARRAY<STRUCT<`ba`: STRING, `bb`: STRING>>>>") + checkAnswer(unionDf, + Row(Row(Seq(Seq(Row("ba", "bb"))))) :: + Row(Row(Seq(Seq(Row("ba", "bb"))))) :: Nil) + + unionDf = df2.unionByName(df1) + assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<" + + "`b`: ARRAY<ARRAY<STRUCT<`bb`: STRING, `ba`: STRING>>>>") + checkAnswer(unionDf, + Row(Row(Seq(Seq(Row("bb", "ba"))))) :: + Row(Row(Seq(Seq(Row("bb", "ba"))))) :: Nil) + + val nestedStructType3 = StructType(Seq( + StructField("b", ArrayType( + ArrayType( + StructType(Seq( + StructField("ba", StringType) + )) + ) + )) + )) + val nestedStructValues3 = Row(Seq(Seq(Row("ba")))) + + val nestedStructType4 = StructType(Seq( + StructField("b", ArrayType( + ArrayType( + StructType(Seq( + StructField("bb", StringType) + )) + ) + )) + )) + val nestedStructValues4 = Row(Seq(Seq(Row("bb")))) + + val df3 = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues3) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType3)))) + + val df4 = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues4) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType4)))) + + assertThrows[AnalysisException] { + df3.unionByName(df4) + } + + unionDf = df3.unionByName(df4, true) + assert(unionDf.schema.toDDL == "`topLevelCol` " + + "STRUCT<`b`: ARRAY<ARRAY<STRUCT<`ba`: STRING, `bb`: STRING>>>>") + 
checkAnswer(unionDf, + Row(Row(Seq(Seq(Row("ba", null))))) :: + Row(Row(Seq(Seq(Row(null, "bb"))))) :: Nil) + + assertThrows[AnalysisException] { + df4.unionByName(df3) + } + + unionDf = df4.unionByName(df3, true) + assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<" + + "`b`: ARRAY<ARRAY<STRUCT<`bb`: STRING, `ba`: STRING>>>>") + checkAnswer(unionDf, + Row(Row(Seq(Seq(Row("bb", null))))) :: + Row(Row(Seq(Seq(Row(null, "ba"))))) :: Nil) + } } case class UnionClass1a(a: Int, b: Long, nested: UnionClass2) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org