Kimahriman commented on a change in pull request #32448:
URL: https://github.com/apache/spark/pull/32448#discussion_r631516151



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
##########
@@ -483,8 +483,8 @@ case class StructType(fields: Array[StructField]) extends 
DataType with Seq[Stru
    * 4. Otherwise, `this` and `that` are considered as conflicting schemas and 
an exception would be
    *    thrown.
    */
-  private[sql] def merge(that: StructType): StructType =
-    StructType.merge(this, that).asInstanceOf[StructType]
+  private[sql] def merge(that: StructType, resolver: Resolver = _ == _): 
StructType =

Review comment:
       Done

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
##########
@@ -21,136 +21,55 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.{CombineUnions, 
OptimizeUpdateFields}
+import org.apache.spark.sql.catalyst.optimizer.{CombineUnions}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, 
Union}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+import org.apache.spark.sql.errors._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
-import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Resolves different children of Union to a common set of columns.
  */
 object ResolveUnion extends Rule[LogicalPlan] {
-  /**
-   * This method sorts columns recursively in a struct expression based on 
column names.
-   */
-  private def sortStructFields(expr: Expression): Expression = {
-    val existingExprs = 
expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
-      case (name, i) =>
-        val fieldExpr = GetStructField(KnownNotNull(expr), i)
-        if (fieldExpr.dataType.isInstanceOf[StructType]) {
-          (name, sortStructFields(fieldExpr))
-        } else {
-          (name, fieldExpr)
-        }
-    }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2))
-
-    val newExpr = CreateNamedStruct(existingExprs)
-    if (expr.nullable) {
-      If(IsNull(expr), Literal(null, newExpr.dataType), newExpr)
-    } else {
-      newExpr
-    }
-  }
-
-  /**
-   * Assumes input expressions are field expression of `CreateNamedStruct`. 
This method
-   * sorts the expressions based on field names.
-   */
-  private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = {
-    fieldExprs.grouped(2).map { e =>
-      Seq(e.head, e.last)
-    }.toSeq.sortBy { pair =>
-      assert(pair.head.isInstanceOf[Literal])
-      pair.head.eval().asInstanceOf[UTF8String].toString
-    }.flatten
-  }
-
-  /**
-   * This helper method sorts fields in a `UpdateFields` expression by field 
name.
-   */
-  private def sortStructFieldsInWithFields(expr: Expression): Expression = 
expr transformUp {
-    case u: UpdateFields if u.resolved =>
-      u.evalExpr match {
-        case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) =>
-          val sorted = sortFieldExprs(fieldExprs)
-          val newStruct = CreateNamedStruct(sorted)
-          i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = 
newStruct)
-        case CreateNamedStruct(fieldExprs) =>
-          val sorted = sortFieldExprs(fieldExprs)
-          val newStruct = CreateNamedStruct(sorted)
-          newStruct
-        case other =>
-          throw new IllegalStateException(s"`UpdateFields` has incorrect 
expression: $other. " +
-            "Please file a bug report with this error message, stack trace, 
and the query.")
-      }
-  }
-
   /**
    * Adds missing fields recursively into given `col` expression, based on the 
target `StructType`.
    * This is called by `compareAndAddFields` when we find two struct columns 
with same name but
    * different nested fields. This method will find out the missing nested 
fields from `col` to
    * `target` struct and add these missing nested fields. Currently we don't 
support finding out
    * missing nested fields of struct nested in array or struct nested in map.
    */
-  private def addFields(col: NamedExpression, target: StructType): Expression 
= {
+  private def addFields(col: Expression, expectedFields: Seq[StructField]): 
Expression = {
     assert(col.dataType.isInstanceOf[StructType], "Only support StructType.")
 
     val resolver = conf.resolver
-    val missingFieldsOpt =
-      StructType.findMissingFields(col.dataType.asInstanceOf[StructType], 
target, resolver)
-
-    // We need to sort columns in result, because we might add another column 
in other side.
-    // E.g., we want to union two structs "a int, b long" and "a int, c 
string".
-    // If we don't sort, we will have "a int, b long, c string" and
-    // "a int, c string, b long", which are not compatible.
-    if (missingFieldsOpt.isEmpty) {
-      sortStructFields(col)
-    } else {
-      missingFieldsOpt.map { s =>
-        val struct = addFieldsInto(col, s.fields)
-        // Combines `WithFields`s to reduce expression tree.
-        val reducedStruct = 
struct.transformUp(OptimizeUpdateFields.optimizeUpdateFields)
-        val sorted = sortStructFieldsInWithFields(reducedStruct)
-        sorted
-      }.get
-    }
-  }
-
-  /**
-   * Adds missing fields recursively into given `col` expression. The missing 
fields are given
-   * in `fields`. For example, given `col` as "z struct<z:int, y:int>, x int", 
and `fields` is
-   * "z struct<w:long>, w string". This method will add a nested `z.w` field 
and a top-level
-   * `w` field to `col` and fill null values for them. Note that because we 
might also add missing
-   * fields at other side of Union, we must make sure corresponding attributes 
at two sides have
-   * same field order in structs, so when we adding missing fields, we will 
sort the fields based on
-   * field names. So the data type of returned expression will be
-   * "w string, x int, z struct<w:long, y:int, z:int>".
-   */
-  private def addFieldsInto(
-      col: Expression,
-      fields: Seq[StructField]): Expression = {
-    fields.foldLeft(col) { case (currCol, field) =>
-      field.dataType match {
-        case st: StructType =>
-          val resolver = conf.resolver
-          val colField = currCol.dataType.asInstanceOf[StructType]
-            .find(f => resolver(f.name, field.name))
-          if (colField.isEmpty) {
-            // The whole struct is missing. Add a null.
-            UpdateFields(currCol, field.name, Literal(null, st))
-          } else {
-            UpdateFields(currCol, field.name,
-              addFieldsInto(ExtractValue(currCol, Literal(field.name), 
resolver), st.fields))
-          }
-        case dt =>
-          UpdateFields(currCol, field.name, Literal(null, dt))
+    val colType = col.dataType.asInstanceOf[StructType]
+    val newStructFields = expectedFields.flatMap(expectedField => {

Review comment:
       Done

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
##########
@@ -21,136 +21,55 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.{CombineUnions, 
OptimizeUpdateFields}
+import org.apache.spark.sql.catalyst.optimizer.{CombineUnions}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, 
Union}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+import org.apache.spark.sql.errors._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
-import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Resolves different children of Union to a common set of columns.
  */
 object ResolveUnion extends Rule[LogicalPlan] {
-  /**
-   * This method sorts columns recursively in a struct expression based on 
column names.
-   */
-  private def sortStructFields(expr: Expression): Expression = {
-    val existingExprs = 
expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
-      case (name, i) =>
-        val fieldExpr = GetStructField(KnownNotNull(expr), i)
-        if (fieldExpr.dataType.isInstanceOf[StructType]) {
-          (name, sortStructFields(fieldExpr))
-        } else {
-          (name, fieldExpr)
-        }
-    }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2))
-
-    val newExpr = CreateNamedStruct(existingExprs)
-    if (expr.nullable) {
-      If(IsNull(expr), Literal(null, newExpr.dataType), newExpr)
-    } else {
-      newExpr
-    }
-  }
-
-  /**
-   * Assumes input expressions are field expression of `CreateNamedStruct`. 
This method
-   * sorts the expressions based on field names.
-   */
-  private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = {
-    fieldExprs.grouped(2).map { e =>
-      Seq(e.head, e.last)
-    }.toSeq.sortBy { pair =>
-      assert(pair.head.isInstanceOf[Literal])
-      pair.head.eval().asInstanceOf[UTF8String].toString
-    }.flatten
-  }
-
-  /**
-   * This helper method sorts fields in a `UpdateFields` expression by field 
name.
-   */
-  private def sortStructFieldsInWithFields(expr: Expression): Expression = 
expr transformUp {
-    case u: UpdateFields if u.resolved =>
-      u.evalExpr match {
-        case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) =>
-          val sorted = sortFieldExprs(fieldExprs)
-          val newStruct = CreateNamedStruct(sorted)
-          i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = 
newStruct)
-        case CreateNamedStruct(fieldExprs) =>
-          val sorted = sortFieldExprs(fieldExprs)
-          val newStruct = CreateNamedStruct(sorted)
-          newStruct
-        case other =>
-          throw new IllegalStateException(s"`UpdateFields` has incorrect 
expression: $other. " +
-            "Please file a bug report with this error message, stack trace, 
and the query.")
-      }
-  }
-
   /**
    * Adds missing fields recursively into given `col` expression, based on the 
target `StructType`.
    * This is called by `compareAndAddFields` when we find two struct columns 
with same name but
    * different nested fields. This method will find out the missing nested 
fields from `col` to
    * `target` struct and add these missing nested fields. Currently we don't 
support finding out
    * missing nested fields of struct nested in array or struct nested in map.
    */
-  private def addFields(col: NamedExpression, target: StructType): Expression 
= {
+  private def addFields(col: Expression, expectedFields: Seq[StructField]): 
Expression = {
     assert(col.dataType.isInstanceOf[StructType], "Only support StructType.")
 
     val resolver = conf.resolver
-    val missingFieldsOpt =
-      StructType.findMissingFields(col.dataType.asInstanceOf[StructType], 
target, resolver)
-
-    // We need to sort columns in result, because we might add another column 
in other side.
-    // E.g., we want to union two structs "a int, b long" and "a int, c 
string".
-    // If we don't sort, we will have "a int, b long, c string" and
-    // "a int, c string, b long", which are not compatible.
-    if (missingFieldsOpt.isEmpty) {
-      sortStructFields(col)
-    } else {
-      missingFieldsOpt.map { s =>
-        val struct = addFieldsInto(col, s.fields)
-        // Combines `WithFields`s to reduce expression tree.
-        val reducedStruct = 
struct.transformUp(OptimizeUpdateFields.optimizeUpdateFields)
-        val sorted = sortStructFieldsInWithFields(reducedStruct)
-        sorted
-      }.get
-    }
-  }
-
-  /**
-   * Adds missing fields recursively into given `col` expression. The missing 
fields are given
-   * in `fields`. For example, given `col` as "z struct<z:int, y:int>, x int", 
and `fields` is
-   * "z struct<w:long>, w string". This method will add a nested `z.w` field 
and a top-level
-   * `w` field to `col` and fill null values for them. Note that because we 
might also add missing
-   * fields at other side of Union, we must make sure corresponding attributes 
at two sides have
-   * same field order in structs, so when we adding missing fields, we will 
sort the fields based on
-   * field names. So the data type of returned expression will be
-   * "w string, x int, z struct<w:long, y:int, z:int>".
-   */
-  private def addFieldsInto(
-      col: Expression,
-      fields: Seq[StructField]): Expression = {
-    fields.foldLeft(col) { case (currCol, field) =>
-      field.dataType match {
-        case st: StructType =>
-          val resolver = conf.resolver
-          val colField = currCol.dataType.asInstanceOf[StructType]
-            .find(f => resolver(f.name, field.name))
-          if (colField.isEmpty) {
-            // The whole struct is missing. Add a null.
-            UpdateFields(currCol, field.name, Literal(null, st))
-          } else {
-            UpdateFields(currCol, field.name,
-              addFieldsInto(ExtractValue(currCol, Literal(field.name), 
resolver), st.fields))
-          }
-        case dt =>
-          UpdateFields(currCol, field.name, Literal(null, dt))
+    val colType = col.dataType.asInstanceOf[StructType]
+    val newStructFields = expectedFields.flatMap(expectedField => {
+      val currentField = colType.fields.find(f => resolver(f.name, 
expectedField.name))
+
+      val newExpression = (currentField, expectedField.dataType) match {
+        case (Some(cf), expectedType: StructType) if 
cf.dataType.isInstanceOf[StructType] =>
+            val extractedValue = ExtractValue(col, Literal(cf.name), resolver)
+            val combinedStruct = addFields(extractedValue, expectedType.fields)
+            if (extractedValue.nullable) {
+              If(IsNull(extractedValue),
+                Literal(null, combinedStruct.dataType),
+                combinedStruct)
+            } else {
+              combinedStruct
+            }
+        case (Some(cf), _) =>
+          ExtractValue(col, Literal(cf.name), resolver)
+        case (_, expectedType) =>

Review comment:
       Done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to