cloud-fan commented on code in PR #40308:
URL: https://github.com/apache/spark/pull/40308#discussion_r1132482968


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala:
##########
@@ -0,0 +1,275 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
AttributeReference, CreateNamedStruct, Expression, ExtractValue, 
GetStructField, Literal, NamedExpression}
+import org.apache.spark.sql.catalyst.plans.logical.Assignment
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+
+object AssignmentUtils extends SQLConfHelper with CastSupport {
+
+  /** An update of one (possibly nested) column: `ref` is the target name path, `expr` its value. */
+  private case class ColumnUpdate(
+      ref: Seq[String],
+      expr: Expression)
+
+  /**
+   * Aligns assignments to match table columns.
+   * <p>
+   * This method processes and reorders given assignments so that each target column gets
+   * an expression it should be set to. If a column does not have a matching assignment,
+   * it will be set to its current value. For example, if one passes table attributes c1, c2
+   * and an assignment c2 = 1, this method will return c1 = c1, c2 = 1.
+   * <p>
+   * This method also handles updates to nested columns. If there is an assignment to a particular
+   * nested field, this method will construct a new struct with one field updated preserving other
+   * fields that have not been modified. For example, if one passes table attributes c1, c2
+   * where c2 is a struct with fields n1 and n2 and an assignment c2.n2 = 1, this method will
+   * return c1 = c1, c2 = struct(c2.n1, 1).
+   * <p>
+   * Problems found in any assignment (conflicting updates, incompatible types, nested updates
+   * to non-struct columns) are collected first and then reported together in a single exception.
+   *
+   * @param attrs table attributes
+   * @param assignments assignments to align
+   * @return aligned assignments that match table columns
+   */
+  def alignAssignments(
+      attrs: Seq[Attribute],
+      assignments: Seq[Assignment]): Seq[Assignment] = {
+
+    // collect every problem instead of failing fast so that all of them are reported at once
+    val collectedErrors = mutable.ArrayBuffer.empty[String]
+    val recordError: String => Unit = msg => collectedErrors.append(msg)
+
+    val columnUpdates = assignments.map(toColumnUpdate)
+    val targetCols = attrs.map(restoreActualType)
+    val alignedExprs = applyUpdates(
+      updates = columnUpdates,
+      cols = targetCols,
+      colExprs = attrs,
+      addError = recordError)
+
+    if (collectedErrors.nonEmpty) {
+      throw QueryCompilationErrors.invalidRowLevelOperationAssignments(
+        assignments, collectedErrors.toSeq)
+    }
+
+    // pair every table attribute with the expression it should be set to
+    attrs.zip(alignedExprs).map { case (attr, expr) => Assignment(attr, expr) }
+  }
+
+  /** Turns an assignment into a column update keyed by the name path of its target (`toRef`). */
+  private def toColumnUpdate(assignment: Assignment): ColumnUpdate =
+    ColumnUpdate(ref = toRef(assignment.key), expr = assignment.value)
+
+  /**
+   * Returns `attr` with the raw data type recorded in its metadata by `CharVarcharUtils`
+   * (e.g. a CHAR/VARCHAR type), keeping its current data type when no raw type is recorded.
+   */
+  private def restoreActualType(attr: Attribute): Attribute = {
+    val rawType = CharVarcharUtils.getRawType(attr.metadata)
+    attr.withDataType(rawType.getOrElse(attr.dataType))
+  }
+
+  /**
+   * Computes the new value of every column at one nesting level.
+   *
+   * @param updates column updates whose refs are relative to this level
+   * @param cols column descriptors (name, type, nullability) at this level
+   * @param colExprs expressions producing the current value of each column in `cols`
+   * @param addError callback used to report problems instead of throwing
+   * @param colPath name path of the enclosing column, used in error messages
+   * @return one expression per column, in the order of `cols`
+   */
+  private def applyUpdates(
+      updates: Seq[ColumnUpdate],
+      cols: Seq[Attribute],
+      colExprs: Seq[Expression],
+      addError: String => Unit,
+      colPath: Seq[String] = Nil): Seq[Expression] = {
+
+    // iterate through columns at the current level and find matching updates
+    cols.zip(colExprs).map { case (col, colExpr) =>
+      // find matches for this column or any of its children
+      val prefixMatchedUpdates = updates.filter(update => conf.resolver(update.ref.head, col.name))
+      prefixMatchedUpdates match {
+        // if there is no exact match and no match for children, return the column expr as is
+        case matchedUpdates if matchedUpdates.isEmpty =>
+          colExpr
+
+        // if there is only one update and it is an exact match, return the assigned expression
+        case Seq(matchedUpdate) if isExactMatch(matchedUpdate, col) =>
+          applyUpdate(matchedUpdate.expr, col, addError, colPath :+ col.name)
+
+        // if there are matches only for children
+        case matchedUpdates if !hasExactMatch(matchedUpdates, col) =>
+          val newColPath = colPath :+ col.name
+          col.dataType match {
+            case colType: StructType =>
+              // build field expressions on top of `colExpr`, which produces this column's
+              // current value; `col` only describes the column (at nested levels it comes from
+              // `colType.toAttributes`, presumably a fresh attribute not produced by any input),
+              // so `GetStructField(col, ...)` would not read the existing nested values
+              val fieldExprs = colType.fields.zipWithIndex.map { case (field, ordinal) =>
+                GetStructField(colExpr, ordinal, Some(field.name))
+              }
+
+              // recursively apply this method to nested fields
+              val updatedFieldExprs = applyUpdates(
+                matchedUpdates.map(update => update.copy(ref = update.ref.tail)),
+                colType.toAttributes,
+                fieldExprs,
+                addError,
+                newColPath)
+
+              // construct a new struct with updated field expressions
+              toNamedStruct(colType, updatedFieldExprs)
+
+            case otherType =>
+              addError(
+                "Updating nested fields is only supported for StructType but " +
+                s"${newColPath.quoted} is of type $otherType")
+              col
+          }
+
+        // anything left has an exact match plus at least one other match, i.e. a conflict;
+        // there are two illegal scenarios:
+        // - multiple updates to the same column
+        // - updates to a top-level struct and its nested fields (like a.b and a.b.c)
+        case matchedUpdates =>
+          val conflictingColNames = matchedUpdates.map(update => (colPath ++ update.ref).quoted)
+          addError("Update conflicts for columns: " + conflictingColNames.distinct.mkString(", "))
+          col
+      }
+    }
+  }
+
+  /**
+   * Builds a `CreateNamedStruct` whose field names come from `structType` and whose values are
+   * the entries of `fieldExprs`, paired by position.
+   */
+  private def toNamedStruct(structType: StructType, fieldExprs: Seq[Expression]): Expression = {
+    val nameValuePairs = fieldExprs.zip(structType.fieldNames).flatMap { case (expr, name) =>
+      Seq(Literal(name), expr)
+    }
+    CreateNamedStruct(nameValuePairs)
+  }
+
+  /** Returns true if any of `updates` targets `col` itself rather than one of its nested fields. */
+  private def hasExactMatch(updates: Seq[ColumnUpdate], col: NamedExpression): Boolean =
+    updates.exists(update => isExactMatch(update, col))
+
+  /** Returns true if `update`'s ref is a single name part that resolves to `col`'s name. */
+  private def isExactMatch(update: ColumnUpdate, col: NamedExpression): Boolean = {
+    val ref = update.ref
+    ref.lengthCompare(1) == 0 && conf.resolver(ref.head, col.name)
+  }
+
+  /**
+   * Converts an assigned `value` into an expression that can be written to `col`.
+   *
+   * Nullability is always checked. If the types already match ignoring nullability, the value
+   * is validated and returned unchanged; struct/array/map pairs are resolved by
+   * `TableOutputResolver` (falling back to `col` when resolution yields nothing); any other pair
+   * is validated and cast to the column type with CHAR/VARCHAR replaced by string, adding
+   * `CharVarcharUtils.stringLengthCheck` when the column has CHAR/VARCHAR and
+   * `charVarcharAsString` is disabled. Problems are reported through `addError`.
+   */
+  private def applyUpdate(
+      value: Expression,
+      col: Attribute,
+      addError: String => Unit,
+      colPath: Seq[String]): Expression = {
+
+    val valueType = value.dataType
+    val colType = col.dataType
+
+    // nullability is validated the same way regardless of how the two types relate
+    TableOutputResolver.checkNullability(value, col, conf, addError, colPath)
+
+    if (valueType.sameType(colType)) {
+      // no need to reorder inner fields or cast if types are equal ignoring nullability
+      validateAssignment(valueType, colType, addError, colPath)
+      value
+    } else {
+      (valueType, colType) match {
+        case (valueStruct: StructType, colStruct: StructType) =>
+          TableOutputResolver.resolveStructType(
+            value, valueStruct, colStruct, conf, addError, colPath).getOrElse(col)
+
+        case (valueArray: ArrayType, colArray: ArrayType) =>
+          TableOutputResolver.resolveArrayType(
+            value, valueArray, colArray, conf, addError, colPath).getOrElse(col)
+
+        case (valueMap: MapType, colMap: MapType) =>
+          TableOutputResolver.resolveMapType(
+            value, valueMap, colMap, conf, addError, colPath).getOrElse(col)
+
+        case _ =>
+          val hasCharVarchar = CharVarcharUtils.hasCharVarchar(colType)
+          // validate and cast against the string-based type; `colType` itself is kept for the
+          // length check below
+          val targetType = if (hasCharVarchar) {
+            CharVarcharUtils.replaceCharVarcharWithString(colType)
+          } else {
+            colType
+          }
+
+          validateAssignment(valueType, targetType, addError, colPath)
+          val casted = TableOutputResolver.cast(value, targetType, conf, colPath.quoted)
+
+          if (!conf.charVarcharAsString && hasCharVarchar) {
+            CharVarcharUtils.stringLengthCheck(casted, colType)
+          } else {
+            casted
+          }
+      }
+    }
+  }
+
+  /**
+   * Under the STRICT and ANSI store assignment policies, checks via `DataType.canWrite` that
+   * `valueType` can be written to `expectedType`, reporting problems through `addError`.
+   * Any other policy skips the check.
+   */
+  private def validateAssignment(
+      valueType: DataType,
+      expectedType: DataType,
+      addError: String => Unit,
+      colPath: Seq[String]): Unit = {
+
+    val policy = conf.storeAssignmentPolicy
+    val enforcesTypeRules =
+      policy == StoreAssignmentPolicy.STRICT || policy == StoreAssignmentPolicy.ANSI
+
+    if (enforcesTypeRules) {
+      DataType.canWrite(
+        valueType, expectedType, byName = true, conf.resolver, colPath.quoted,
+        policy, addError)
+    }
+  }
+
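+  /**
+   * Checks whether assignments are aligned and are compatible with table columns.
+   *
+   * @param attrs table attributes
+   * @param assignments assignments to check
+   * @return true if there is exactly one assignment per attribute, in the same order, each
+   *         targeting its attribute with a value of a compatible type and nullability
+   */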
+  def aligned(attrs: Seq[Attribute], assignments: Seq[Assignment]): Boolean = {
+    // the size check short-circuits, so the per-column checks only run when there is exactly
+    // one assignment per attribute
+    attrs.size == assignments.size && attrs.zip(assignments).forall { case (attr, assignment) =>
+      val key = assignment.key
+      val value = assignment.value
+
+      // compare against the raw type kept in the attribute metadata, if any
+      val attrType = CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType)
+
+      sameRef(toRef(key), toRef(attr)) &&
+        DataType.equalsIgnoreCompatibleNullability(value.dataType, attrType) &&
+        // a nullable value may only be assigned to a nullable column
+        (attr.nullable || !value.nullable)
+    }
+  }
+
+  /** Returns true if both name paths have the same length and each pair of parts resolves. */
+  private def sameRef(ref: Seq[String], otherRef: Seq[String]): Boolean =
+    ref.corresponds(otherRef)(conf.resolver)
+
+  private def toRef(expr: Expression): Seq[String] = expr match {
+    case attr: AttributeReference =>
+      Seq(attr.name)
+    case Alias(child, _) =>
+      toRef(child)
+    case GetStructField(child, _, Some(name)) =>
+      toRef(child) :+ name
+    case other: ExtractValue =>

Review Comment:
   how about `GetArrayStructField`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to