viirya commented on a change in pull request #33728:
URL: https://github.com/apache/spark/pull/33728#discussion_r689244868



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
##########
@@ -74,6 +66,160 @@ object TableOutputResolver {
     }
   }
 
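+  // Reorders `inputCols` to match the names of `expectedCols`, resolving names
+  // with `conf.resolver` and recursing into struct, array, and map types.
+  // Missing, ambiguous, and extra columns are reported through `addError`;
+  // on any failure an empty Seq is returned.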
+  private def reorderColumnsByName(
+      inputCols: Seq[NamedExpression],
+      expectedCols: Seq[Attribute],
+      conf: SQLConf,
+      addError: String => Unit,
+      colPath: Seq[String] = Nil): Seq[NamedExpression] = {
+    val matchedCols = mutable.HashSet.empty[String]
+    val reordered = expectedCols.flatMap { expectedCol =>
+      val matched = inputCols.filter(col => conf.resolver(col.name, expectedCol.name))
+      val newColPath = colPath :+ expectedCol.name
+      if (matched.isEmpty) {
+        addError(s"Cannot find data for output column '${newColPath.quoted}'")
+        None
+      } else if (matched.length > 1) {
+        addError(s"Ambiguous column name in the input data: 
'${newColPath.quoted}'")
+        None
+      } else {
+        matchedCols += matched.head.name
+        val expectedName = expectedCol.name
+        val matchedCol = matched.head match {
+          // Save an Alias if we can change the name directly.
+          case a: Attribute => a.withName(expectedName)
+          case a: Alias => a.withName(expectedName)
+          case other => other
+        }
+        (matchedCol.dataType, expectedCol.dataType) match {
+          case (matchedType: StructType, expectedType: StructType) =>
+            checkNullability(matchedCol, expectedCol, conf, addError, newColPath)
+            resolveStructType(
+              matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath)
+          case (matchedType: ArrayType, expectedType: ArrayType) =>
+            checkNullability(matchedCol, expectedCol, conf, addError, newColPath)
+            resolveArrayType(
+              matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath)
+          case (matchedType: MapType, expectedType: MapType) =>
+            checkNullability(matchedCol, expectedCol, conf, addError, newColPath)
+            resolveMapType(
+              matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath)
+          case _ =>
+            checkField(expectedCol, matchedCol, byName = true, conf, addError)
+        }
+      }
+    }
+
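+    // Succeed only if every expected column matched exactly once and no
+    // input column is left over.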
+    if (reordered.length == expectedCols.length) {
+      if (matchedCols.size < inputCols.length) {
+        val extraCols = inputCols.filterNot(col => matchedCols.contains(col.name))
+          .map(col => s"'${col.name}'").mkString(", ")
+        addError(s"Cannot write extra fields to struct '${colPath.quoted}': $extraCols")
+        Nil
+      } else {
+        reordered
+      }
+    } else {
+      Nil
+    }
+  }
+
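+  // Reports an error when a nullable input would be written to a non-null
+  // column, unless the legacy store-assignment policy is in effect.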
+  private def checkNullability(
+      input: Expression,
+      expected: Attribute,
+      conf: SQLConf,
+      addError: String => Unit,
+      colPath: Seq[String]): Unit = {
+    if (input.nullable && !expected.nullable &&
+      conf.storeAssignmentPolicy != StoreAssignmentPolicy.LEGACY) {
+      addError(s"Cannot write nullable values to non-null column 
'${colPath.quoted}'")
+    }
+  }
+
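+  // Rebuilds a struct column field by field so that nested fields can be
+  // reordered by name, then reassembles them with CreateStruct.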
+  private def resolveStructType(
+      input: NamedExpression,
+      inputType: StructType,
+      expectedType: StructType,
+      expectedName: String,
+      conf: SQLConf,
+      addError: String => Unit,
+      colPath: Seq[String]): Option[NamedExpression] = {
+    val fields = inputType.zipWithIndex.map { case (f, i) =>
+      Alias(GetStructField(input, i, Some(f.name)), f.name)()
+    }
+    val reordered = reorderColumnsByName(fields, expectedType.toAttributes, conf, addError, colPath)
+    if (reordered.length == expectedType.length) {
+      val struct = CreateStruct(reordered)
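+      // Preserve top-level nullability: a null input struct stays null
+      // instead of becoming a struct of null fields.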
+      val res = if (input.nullable) {
+        If(IsNull(input), Literal(null, struct.dataType), struct)
+      } else {
+        struct
+      }
+      Some(Alias(res, expectedName)())
+    } else {
+      None
+    }
+  }
+
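+  // Reorders nested fields inside array elements by rewriting every element
+  // with ArrayTransform and a lambda over a synthetic element attribute.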
+  private def resolveArrayType(
+      input: NamedExpression,
+      inputType: ArrayType,
+      expectedType: ArrayType,
+      expectedName: String,
+      conf: SQLConf,
+      addError: String => Unit,
+      colPath: Seq[String]): Option[NamedExpression] = {
+    if (inputType.containsNull && !expectedType.containsNull) {
+      addError(s"Cannot write nullable elements to array of non-nulls: 
'${colPath.quoted}'")
+      None
+    } else {
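+      // Bind a lambda variable to the element type so reorderColumnsByName
+      // can recurse into the element's nested fields.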
+      val param = NamedLambdaVariable("x", inputType.elementType, 
inputType.containsNull)
+      val fakeAttr = AttributeReference("x", expectedType.elementType, 
expectedType.containsNull)()
+      val res = reorderColumnsByName(Seq(param), Seq(fakeAttr), conf, 
addError, colPath)
+      if (res.length == 1) {
+        val func = LambdaFunction(res.head, Seq(param))
+        Some(Alias(ArrayTransform(input, func), expectedName)())
+      } else {
+        None
+      }
+    }
+  }
+
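+  // Reorders nested fields inside a map column by transforming its keys and
+  // values separately and reassembling them with MapFromArrays.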
+  private def resolveMapType(
+      input: NamedExpression,
+      inputType: MapType,
+      expectedType: MapType,
+      expectedName: String,
+      conf: SQLConf,
+      addError: String => Unit,
+      colPath: Seq[String]): Option[NamedExpression] = {
+    if (inputType.valueContainsNull && !expectedType.valueContainsNull) {
+      addError(s"Cannot write nullable values to map of non-nulls: 
'${colPath.quoted}'")
+      None
+    } else {
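+      // Map keys are never null; values keep the input's value nullability.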
+      val keyParam = NamedLambdaVariable("k", inputType.keyType, nullable = false)
+      val fakeKeyAttr = AttributeReference("k", expectedType.keyType, nullable = false)()
+      val resKey = reorderColumnsByName(
+        Seq(keyParam), Seq(fakeKeyAttr), conf, addError, colPath :+ "key")
+
+      val valueParam = NamedLambdaVariable("v", inputType.valueType, inputType.valueContainsNull)
+      val fakeValueAttr =
+        AttributeReference("v", expectedType.valueType, expectedType.valueContainsNull)()
+      val resValue = reorderColumnsByName(
+        Seq(valueParam), Seq(fakeValueAttr), conf, addError, colPath :+ "value")
+
+      if (resKey.length == 1 && resValue.length == 1) {
+        val keyFunc = LambdaFunction(resKey.head, Seq(keyParam))
+        val valueFunc = LambdaFunction(resValue.head, Seq(valueParam))
+        val newKeys = ArrayTransform(MapKeys(input), keyFunc)
+        val newValues = ArrayTransform(MapValues(input), valueFunc)
+        Some(Alias(MapFromArrays(newKeys, newValues), expectedName)())

Review comment:
       ok




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


