This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new 50f4adf1a4dd [SPARK-54289][SQL] Allow MERGE INTO to preserve existing
struct fields for UPDATE SET * when source struct has less nested fields than
target struct
50f4adf1a4dd is described below
commit 50f4adf1a4ddd27a2e44eb8d4111b9974c8077da
Author: Szehon Ho <[email protected]>
AuthorDate: Fri Nov 21 20:05:19 2025 -0800
[SPARK-54289][SQL] Allow MERGE INTO to preserve existing struct fields for
UPDATE SET * when source struct has less nested fields than target struct
### What changes were proposed in this pull request?
Introduce a new flag spark.sql.merge.nested.type.assign.by.field that
allows the UPDATE SET * action in MERGE INTO to be a shorthand to assign every
nested struct field to its existing source counterpart (ie, UPDATE SET a.b.c =
source.a.b.c). This will have the implication that existing struct fields in
the target table that have no source equivalent are preserved, when the
corresponding source struct has fewer fields than the target.
Additional code is added to prevent null expansion in this case (ie, a null
source struct expanding to a struct of nulls).
### Why are the changes needed?
Following https://github.com/apache/spark/pull/52347, we now allow MERGE
INTO to have a source table struct with fewer nested fields than the target
table struct. In this scenario, a user issuing an UPDATE SET * may have two
interpretations.
The user may interpret UPDATE SET * as shorthand to assign every top-level
column, ie UPDATE SET struct=source.struct, then the target struct is set
to source struct object as is, with missing fields as NULL. This is the
current behavior.
The user may also mean that UPDATE SET * is short-hand to assign every
nested struct field (ie, UPDATE SET struct.a.b = source.struct.a.b), in which
case the target struct fields missing in source are retained. This is similar
to UPDATE SET * not overriding existing target columns missing in the source,
for example. For this case, this flag is added.
### Does this PR introduce _any_ user-facing change?
No, the support to allow source structs to have fewer fields than target
structs in MERGE INTO has not been released yet
(https://github.com/apache/spark/pull/52347), and in any case there is a flag
to toggle this functionality.
### How was this patch tested?
Unit tests, especially around cases where the source struct is null.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53149 from szehon-ho/merge_schema_evolution_update_nested.
Authored-by: Szehon Ho <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit 966e0539c67b8b5ce9219c0bab60656e6ed1364b)
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 13 +-
.../sql/catalyst/analysis/AssignmentUtils.scala | 195 ++++-
.../ResolveRowLevelCommandAssignments.scala | 11 +-
.../catalyst/analysis/RewriteMergeIntoTable.scala | 10 +-
.../ReplaceNullWithFalseInPredicate.scala | 3 +-
.../sql/catalyst/plans/logical/v2Commands.scala | 23 +-
.../spark/sql/catalyst/types/DataTypeUtils.scala | 33 +
.../org/apache/spark/sql/internal/SQLConf.scala | 23 +-
.../sql/connector/MergeIntoTableSuiteBase.scala | 903 ++++++++++++++++++---
.../command/AlignMergeAssignmentsSuite.scala | 28 +-
.../execution/command/PlanResolutionSuite.scala | 40 +-
11 files changed, 1075 insertions(+), 207 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6b0665c1b7f3..fb8a84a85fc8 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1709,14 +1709,15 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
val resolvedDeleteCondition = deleteCondition.map(
resolveExpressionByPlanChildren(_, m))
DeleteAction(resolvedDeleteCondition)
- case UpdateAction(updateCondition, assignments) =>
+ case UpdateAction(updateCondition, assignments, fromStar) =>
val resolvedUpdateCondition = updateCondition.map(
resolveExpressionByPlanChildren(_, m))
UpdateAction(
resolvedUpdateCondition,
// The update value can access columns from both target and
source tables.
resolveAssignments(assignments, m, MergeResolvePolicy.BOTH,
- throws = throws))
+ throws = throws),
+ fromStar)
case UpdateStarAction(updateCondition) =>
// Expand star to top level source columns. If source has
less columns than target,
// assignments will be added by
ResolveRowLevelCommandAssignments later.
@@ -1738,7 +1739,8 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
updateCondition.map(resolveExpressionByPlanChildren(_, m)),
// For UPDATE *, the value must be from source table.
resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
- throws = throws))
+ throws = throws),
+ fromStar = true)
case o => o
}
val newNotMatchedActions = m.notMatchedActions.map {
@@ -1783,14 +1785,15 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
val resolvedDeleteCondition = deleteCondition.map(
resolveExpressionByPlanOutput(_, targetTable))
DeleteAction(resolvedDeleteCondition)
- case UpdateAction(updateCondition, assignments) =>
+ case UpdateAction(updateCondition, assignments, fromStar) =>
val resolvedUpdateCondition = updateCondition.map(
resolveExpressionByPlanOutput(_, targetTable))
UpdateAction(
resolvedUpdateCondition,
// The update value can access columns from the target table
only.
resolveAssignments(assignments, m, MergeResolvePolicy.TARGET,
- throws = throws))
+ throws = throws),
+ fromStar)
case o => o
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
index 145c9077a4c2..6cbc17c67381 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
@@ -21,13 +21,15 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.SQLConfHelper
import
org.apache.spark.sql.catalyst.analysis.TableOutputResolver.DefaultValueFillMode.{NONE,
RECURSE}
-import org.apache.spark.sql.catalyst.expressions.{Attribute,
CreateNamedStruct, Expression, GetStructField, Literal}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute,
CreateNamedStruct, Expression, GetStructField, If, IsNull, Literal}
+import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical.Assignment
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import
org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLit
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.ArrayImplicits._
@@ -50,13 +52,18 @@ object AssignmentUtils extends SQLConfHelper with
CastSupport {
*
* @param attrs table attributes
* @param assignments assignments to align
+ * @param fromStar whether the assignments were resolved from an UPDATE SET
* clause.
+ * These updates may assign struct fields individually
+ * (preserving existing fields).
* @param coerceNestedTypes whether to coerce nested types to match the
target type
* for complex types
+ * @param missingSourcePaths paths that exist in target but not in source
* @return aligned update assignments that match table attributes
*/
def alignUpdateAssignments(
attrs: Seq[Attribute],
assignments: Seq[Assignment],
+ fromStar: Boolean,
coerceNestedTypes: Boolean): Seq[Assignment] = {
val errors = new mutable.ArrayBuffer[String]()
@@ -68,7 +75,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport
{
assignments,
addError = err => errors += err,
colPath = Seq(attr.name),
- coerceNestedTypes)
+ coerceNestedTypes,
+ fromStar)
}
if (errors.nonEmpty) {
@@ -152,7 +160,8 @@ object AssignmentUtils extends SQLConfHelper with
CastSupport {
assignments: Seq[Assignment],
addError: String => Unit,
colPath: Seq[String],
- coerceNestedTypes: Boolean = false): Expression = {
+ coerceNestedTypes: Boolean = false,
+ updateStar: Boolean = false): Expression = {
val (exactAssignments, otherAssignments) = assignments.partition {
assignment =>
assignment.key.semanticEquals(colExpr)
@@ -174,9 +183,31 @@ object AssignmentUtils extends SQLConfHelper with
CastSupport {
} else if (exactAssignments.isEmpty && fieldAssignments.isEmpty) {
TableOutputResolver.checkNullability(colExpr, col, conf, colPath)
} else if (exactAssignments.nonEmpty) {
- val value = exactAssignments.head.value
- val coerceMode = if (coerceNestedTypes) RECURSE else NONE
- TableOutputResolver.resolveUpdate("", value, col, conf, addError,
colPath, coerceMode)
+ if (SQLConf.get.mergeUpdateStructsByField && updateStar) {
+ val value = exactAssignments.head.value
+ col.dataType match {
+ case structType: StructType =>
+ // Expand assignments to leaf fields
+ val structAssignment =
+ applyNestedFieldAssignments(col, colExpr, value, addError,
colPath,
+ coerceNestedTypes)
+
+ // Wrap with null check for missing source fields
+ fixNullExpansion(col, value, structType, structAssignment,
+ colPath, addError)
+ case _ =>
+ // For non-struct types, resolve directly
+ val coerceMode = if (coerceNestedTypes) RECURSE else NONE
+ TableOutputResolver.resolveUpdate("", value, col, conf, addError,
colPath,
+ coerceMode)
+ }
+ } else {
+ val value = exactAssignments.head.value
+ val coerceMode = if (coerceNestedTypes) RECURSE else NONE
+ val resolvedValue = TableOutputResolver.resolveUpdate("", value, col,
conf, addError,
+ colPath, coerceMode)
+ resolvedValue
+ }
} else {
applyFieldAssignments(col, colExpr, fieldAssignments, addError, colPath,
coerceNestedTypes)
}
@@ -210,6 +241,63 @@ object AssignmentUtils extends SQLConfHelper with
CastSupport {
}
}
+ private def applyNestedFieldAssignments(
+ col: Attribute,
+ colExpr: Expression,
+ value: Expression,
+ addError: String => Unit,
+ colPath: Seq[String],
+ coerceNestedTyptes: Boolean): Expression = {
+
+ col.dataType match {
+ case structType: StructType =>
+ val fieldAttrs = DataTypeUtils.toAttributes(structType)
+
+ val updatedFieldExprs = fieldAttrs.zipWithIndex.map { case (fieldAttr,
ordinal) =>
+ val fieldPath = colPath :+ fieldAttr.name
+ val targetFieldExpr = GetStructField(colExpr, ordinal,
Some(fieldAttr.name))
+
+ // Try to find a corresponding field in the source value by name
+ val sourceFieldValue: Expression = value.dataType match {
+ case valueStructType: StructType =>
+ valueStructType.fields.find(f => conf.resolver(f.name,
fieldAttr.name)) match {
+ case Some(matchingField) =>
+ // Found matching field in source, extract it
+ val fieldIndex =
valueStructType.fieldIndex(matchingField.name)
+ GetStructField(value, fieldIndex, Some(matchingField.name))
+ case None =>
+ // Field doesn't exist in source, use target's current value
with null check
+ TableOutputResolver.checkNullability(targetFieldExpr,
fieldAttr, conf, fieldPath)
+ }
+ case _ =>
+ // Value is not a struct, cannot extract field
+ addError(s"Cannot assign non-struct value to struct field
'${fieldPath.quoted}'")
+ Literal(null, fieldAttr.dataType)
+ }
+
+ // Recurse or resolve based on field type
+ fieldAttr.dataType match {
+ case nestedStructType: StructType =>
+ // Field is a struct, recurse
+ applyNestedFieldAssignments(fieldAttr, targetFieldExpr,
sourceFieldValue,
+ addError, fieldPath, coerceNestedTyptes)
+ case _ =>
+ // Field is not a struct, resolve with TableOutputResolver
+ val coerceMode = if (coerceNestedTyptes) RECURSE else NONE
+ TableOutputResolver.resolveUpdate("", sourceFieldValue,
fieldAttr, conf, addError,
+ fieldPath, coerceMode)
+ }
+ }
+ toNamedStruct(structType, updatedFieldExprs)
+
+ case otherType =>
+ addError(
+ "Updating nested fields is only supported for StructType but " +
+ s"'${colPath.quoted}' is of type $otherType")
+ colExpr
+ }
+ }
+
private def toNamedStruct(structType: StructType, fieldExprs:
Seq[Expression]): Expression = {
val namedStructExprs = structType.fields.zip(fieldExprs).flatMap { case
(field, expr) =>
Seq(Literal(field.name), expr)
@@ -217,6 +305,101 @@ object AssignmentUtils extends SQLConfHelper with
CastSupport {
CreateNamedStruct(namedStructExprs)
}
+ private def getMissingSourcePaths(targetType: StructType,
+ sourceType: DataType,
+ colPath: Seq[String],
+ addError: String => Unit): Seq[Seq[String]] =
{
+ val nestedTargetPaths = DataTypeUtils.extractLeafFieldPaths(targetType,
Seq.empty)
+ val nestedSourcePaths = sourceType match {
+ case sourceStructType: StructType =>
+ DataTypeUtils.extractLeafFieldPaths(sourceStructType, Seq.empty)
+ case _ =>
+ addError(s"Value for struct type: " +
+ s"${colPath.quoted} must be a struct but was
${sourceType.simpleString}")
+ Seq()
+ }
+ nestedSourcePaths.diff(nestedTargetPaths)
+ }
+
+ /**
+ * Creates a null check for a field at the given path within a struct
expression.
+ * Navigates through the struct hierarchy following the path and returns an
IsNull check
+ * for the final field.
+ *
+ * @param rootExpr the root expression to navigate from
+ * @param path the field path to navigate (sequence of field names)
+ * @return an IsNull expression checking if the field at the path is null
+ */
+ private def createNullCheckForFieldPath(
+ rootExpr: Expression,
+ path: Seq[String]): Expression = {
+ var currentExpr: Expression = rootExpr
+ path.foreach { fieldName =>
+ currentExpr.dataType match {
+ case st: StructType =>
+ st.fields.find(f => conf.resolver(f.name, fieldName)) match {
+ case Some(field) =>
+ val fieldIndex = st.fieldIndex(field.name)
+ currentExpr = GetStructField(currentExpr, fieldIndex,
Some(field.name))
+ case None =>
+ // Field not found, shouldn't happen
+ }
+ case _ =>
+ // Not a struct, shouldn't happen
+ }
+ }
+ IsNull(currentExpr)
+ }
+
+ /**
+ * As UPDATE SET * can assign struct fields individually (preserving
existing fields),
+ * this will lead to null expansion, ie, a struct is created where all
fields are null.
+ * Wraps a struct assignment with null checks for the source and missing
source fields.
+ * Return null if all are null.
+ *
+ * @param col the target column attribute
+ * @param value the source value expression
+ * @param structType the target struct type
+ * @param structAssignment the struct assignment result to wrap
+ * @param colPath the column path for error reporting
+ * @param addError error reporting function
+ * @return the wrapped expression with null checks
+ */
+ private def fixNullExpansion(
+ col: Attribute,
+ value: Expression,
+ structType: StructType,
+ structAssignment: Expression,
+ colPath: Seq[String],
+ addError: String => Unit): Expression = {
+ // As StoreAssignmentPolicy.LEGACY is not allowed in DSv2, always add null
check for
+ // non-nullable column
+ if (!col.nullable) {
+ AssertNotNull(value)
+ } else {
+ // Check if source struct is null
+ val valueIsNull = IsNull(value)
+
+ // Check if missing source paths (paths in target but not in source) are
not null
+ // These will be null for the case of UPDATE SET * and
+ val missingSourcePaths = getMissingSourcePaths(structType,
value.dataType, colPath, addError)
+ val condition = if (missingSourcePaths.nonEmpty) {
+ // Check if all target attributes at missing source paths are null
+ val missingFieldNullChecks = missingSourcePaths.map { path =>
+ createNullCheckForFieldPath(col, path)
+ }
+ // Combine all null checks with AND
+ val allMissingFieldsNull =
missingFieldNullChecks.reduce[Expression]((a, b) => And(a, b))
+ And(valueIsNull, allMissingFieldsNull)
+ } else {
+ valueIsNull
+ }
+
+ // Return: If (condition) THEN NULL ELSE structAssignment
+ If(condition, Literal(null, structAssignment.dataType), structAssignment)
+ }
+ }
+
/**
* Checks whether assignments are aligned and compatible with table columns.
*
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
index 3eb528954b35..93ef98e3183a 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
@@ -44,7 +44,7 @@ object ResolveRowLevelCommandAssignments extends
Rule[LogicalPlan] {
validateStoreAssignmentPolicy()
val newTable = cleanAttrMetadata(u.table)
val newAssignments =
AssignmentUtils.alignUpdateAssignments(u.table.output, u.assignments,
- coerceNestedTypes = false)
+ fromStar = false, coerceNestedTypes = false)
u.copy(table = newTable, assignments = newAssignments)
case u: UpdateTable if !u.skipSchemaResolution && u.resolved && !u.aligned
=>
@@ -53,10 +53,11 @@ object ResolveRowLevelCommandAssignments extends
Rule[LogicalPlan] {
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved &&
m.rewritable && !m.aligned &&
!m.needSchemaEvolution =>
validateStoreAssignmentPolicy()
- val coerceNestedTypes = SQLConf.get.coerceMergeNestedTypes
+ val coerceNestedTypes = SQLConf.get.mergeCoerceNestedTypes
m.copy(
targetTable = cleanAttrMetadata(m.targetTable),
- matchedActions = alignActions(m.targetTable.output, m.matchedActions,
coerceNestedTypes),
+ matchedActions = alignActions(m.targetTable.output, m.matchedActions,
+ coerceNestedTypes),
notMatchedActions = alignActions(m.targetTable.output,
m.notMatchedActions,
coerceNestedTypes),
notMatchedBySourceActions = alignActions(m.targetTable.output,
m.notMatchedBySourceActions,
@@ -117,9 +118,9 @@ object ResolveRowLevelCommandAssignments extends
Rule[LogicalPlan] {
actions: Seq[MergeAction],
coerceNestedTypes: Boolean): Seq[MergeAction] = {
actions.map {
- case u @ UpdateAction(_, assignments) =>
+ case u @ UpdateAction(_, assignments, fromStar) =>
u.copy(assignments = AssignmentUtils.alignUpdateAssignments(attrs,
assignments,
- coerceNestedTypes))
+ fromStar, coerceNestedTypes))
case d: DeleteAction =>
d
case i @ InsertAction(_, assignments) =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index 8b5b690aa740..1d2e2fef2096 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -334,7 +334,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand
with PredicateHelper
// original row ID values must be preserved and passed back to the table
to encode updates
// if there are any assignments to row ID attributes, add extra columns
for original values
val updateAssignments = (matchedActions ++
notMatchedBySourceActions).flatMap {
- case UpdateAction(_, assignments) => assignments
+ case UpdateAction(_, assignments, _) => assignments
case _ => Nil
}
buildOriginalRowIdValues(rowIdAttrs, updateAssignments)
@@ -434,7 +434,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand
with PredicateHelper
// converts a MERGE action into an instruction on top of the joined plan for
group-based plans
private def toInstruction(action: MergeAction, metadataAttrs:
Seq[Attribute]): Instruction = {
action match {
- case UpdateAction(cond, assignments) =>
+ case UpdateAction(cond, assignments, _) =>
val rowValues = assignments.map(_.value)
val metadataValues = nullifyMetadataOnUpdate(metadataAttrs)
val output = Seq(Literal(WRITE_WITH_METADATA_OPERATION)) ++ rowValues
++ metadataValues
@@ -466,12 +466,12 @@ object RewriteMergeIntoTable extends
RewriteRowLevelCommand with PredicateHelper
splitUpdates: Boolean): Instruction = {
action match {
- case UpdateAction(cond, assignments) if splitUpdates =>
+ case UpdateAction(cond, assignments, _) if splitUpdates =>
val output = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs,
originalRowIdValues)
val otherOutput = deltaReinsertOutput(assignments, metadataAttrs,
originalRowIdValues)
Split(cond.getOrElse(TrueLiteral), output, otherOutput)
- case UpdateAction(cond, assignments) =>
+ case UpdateAction(cond, assignments, _) =>
val output = deltaUpdateOutput(assignments, metadataAttrs,
originalRowIdValues)
Keep(Update, cond.getOrElse(TrueLiteral), output)
@@ -495,7 +495,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand
with PredicateHelper
val actions = merge.matchedActions ++ merge.notMatchedActions ++
merge.notMatchedBySourceActions
actions.foreach {
case DeleteAction(Some(cond)) => checkMergeIntoCondition("DELETE", cond)
- case UpdateAction(Some(cond), _) => checkMergeIntoCondition("UPDATE",
cond)
+ case UpdateAction(Some(cond), _, _) => checkMergeIntoCondition("UPDATE",
cond)
case InsertAction(Some(cond), _) => checkMergeIntoCondition("INSERT",
cond)
case _ => // OK
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
index 7134c3daf3ba..9a676571d107 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
@@ -149,7 +149,8 @@ object ReplaceNullWithFalseInPredicate extends
Rule[LogicalPlan] {
private def replaceNullWithFalse(mergeActions: Seq[MergeAction]):
Seq[MergeAction] = {
mergeActions.map {
- case u @ UpdateAction(Some(cond), _) => u.copy(condition =
Some(replaceNullWithFalse(cond)))
+ case u @ UpdateAction(Some(cond), _, _) =>
+ u.copy(condition = Some(replaceNullWithFalse(cond)))
case u @ UpdateStarAction(Some(cond)) => u.copy(condition =
Some(replaceNullWithFalse(cond)))
case d @ DeleteAction(Some(cond)) => d.copy(condition =
Some(replaceNullWithFalse(cond)))
case i @ InsertAction(Some(cond), _) => i.copy(condition =
Some(replaceNullWithFalse(cond)))
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index dcce22040244..26ce138523e7 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -873,7 +873,7 @@ case class MergeIntoTable(
lazy val aligned: Boolean = {
val actions = matchedActions ++ notMatchedActions ++
notMatchedBySourceActions
actions.forall {
- case UpdateAction(_, assignments) =>
+ case UpdateAction(_, assignments, _) =>
AssignmentUtils.aligned(targetTable.output, assignments)
case _: DeleteAction =>
true
@@ -926,10 +926,7 @@ case class MergeIntoTable(
case a: UpdateAction => a.assignments
case a: InsertAction => a.assignments
}.flatten
-
- val sourcePaths = MergeIntoTable.extractAllFieldPaths(sourceTable.schema)
- // Only allow unresolved assignment keys to be candidates for schema
evolution
- // if they are directly assigned from source fields, ie UPDATE SET new =
source.new
+ val sourcePaths = DataTypeUtils.extractAllFieldPaths(sourceTable.schema)
assignments.forall { assignment =>
assignment.resolved ||
(assignment.value.resolved && sourcePaths.exists {
@@ -1083,19 +1080,6 @@ object MergeIntoTable {
filterSchema(merge.sourceTable.schema, Seq.empty)
}
- private def extractAllFieldPaths(schema: StructType, basePath: Seq[String] =
Seq.empty):
- Seq[Seq[String]] = {
- schema.flatMap { field =>
- val fieldPath = basePath :+ field.name
- field.dataType match {
- case struct: StructType =>
- fieldPath +: extractAllFieldPaths(struct, fieldPath)
- case _ =>
- Seq(fieldPath)
- }
- }
- }
-
// Helper method to extract field path from an Expression.
private def extractFieldPath(expr: Expression, allowUnresolved: Boolean):
Seq[String] = {
expr match {
@@ -1142,7 +1126,8 @@ case class DeleteAction(condition: Option[Expression])
extends MergeAction {
case class UpdateAction(
condition: Option[Expression],
- assignments: Seq[Assignment]) extends MergeAction {
+ assignments: Seq[Assignment],
+ fromStar: Boolean = false) extends MergeAction {
override def children: Seq[Expression] = condition.toSeq ++ assignments
override protected def withNewChildrenInternal(
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
index c6e51aab4584..e7bd5bd1aa2d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
@@ -249,5 +249,38 @@ object DataTypeUtils {
case v: Long => fromDecimal(Decimal(BigDecimal(v)))
case _ => forType(literal.dataType)
}
+
+ /**
+ * Extracts all struct field paths from a nested StructType.
+ */
+ def extractAllFieldPaths(schema: StructType, basePath: Seq[String] =
Seq.empty):
+ Seq[Seq[String]] = {
+ schema.flatMap { field =>
+ val fieldPath = basePath :+ field.name
+ field.dataType match {
+ case struct: StructType =>
+ fieldPath +: extractAllFieldPaths(struct, fieldPath)
+ case _ =>
+ Seq(fieldPath)
+ }
+ }
+ }
+
+ /**
+ * Extracts only leaf-level field paths from a nested StructType.
+ * Unlike extractAllFieldPaths, this method does not include intermediate
struct paths.
+ */
+ def extractLeafFieldPaths(schema: StructType, basePath: Seq[String] =
Seq.empty):
+ Seq[Seq[String]] = {
+ schema.flatMap { field =>
+ val fieldPath = basePath :+ field.name
+ field.dataType match {
+ case struct: StructType =>
+ extractLeafFieldPaths(struct, fieldPath)
+ case _ =>
+ Seq(fieldPath)
+ }
+ }
+ }
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5ae5fa33a721..9633011cf872 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -6611,8 +6611,8 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
- val MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED =
- buildConf("spark.sql.merge.source.nested.type.coercion.enabled")
+ val MERGE_INTO_NESTED_TYPE_COERCION_ENABLED =
+ buildConf("spark.sql.merge.nested.type.coercion.enabled")
.internal()
.doc("If enabled, allow MERGE INTO to coerce source nested types if they
have less" +
"nested fields than the target table's nested types.")
@@ -6620,6 +6620,18 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD =
+ buildConf("spark.sql.merge.nested.type.assign.by.field")
+ .internal()
+ .doc("If enabled and spark.sql.merge.source.nested.type.coercion.enabled
is true," +
+ "allow MERGE INTO with UPDATE SET * action to set nested structs field
by field. " +
+ "In updated rows, target structs will preserve the original value for
fields missing " +
+ "in the the source struct. If disabled, the entire target struct will
be replaced, " +
+ "and fields missing in the source struct will be null.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(true)
+
/**
* Holds information about keys that have been deprecated.
*
@@ -7777,8 +7789,11 @@ class SQLConf extends Serializable with Logging with
SqlApiConf {
def legacyXMLParserEnabled: Boolean =
getConf(SQLConf.LEGACY_XML_PARSER_ENABLED)
- def coerceMergeNestedTypes: Boolean =
- getConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED)
+ def mergeCoerceNestedTypes: Boolean =
+ getConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED)
+
+ def mergeUpdateStructsByField: Boolean =
+ getConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD)
/** ********************** SQLConf functionality methods ************ */
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index 85b0faed4c38..7051a0b455e3 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -3231,60 +3231,180 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase
test("merge into schema evolution replace column with nested struct and set
all columns") {
Seq(true, false).foreach { withSchemaEvolution =>
- withTempView("source") {
- createAndInitTable(
- s"""pk INT NOT NULL,
- |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING,
STRING>>>,
- |dep STRING""".stripMargin,
- """{ "pk": 1, "s": { "c1": 2, "c2": { "a": [1,2], "m": { "a": "b" }
} }, "dep": "hr" }""")
+ Seq(true, false).foreach { updateByFields =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key ->
+ updateByFields.toString) {
+ withTempView("source") {
+ // Create table using Spark SQL
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ |pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING,
STRING>>>,
+ |dep STRING)
+ |PARTITIONED BY (dep)
+ |""".stripMargin)
- val sourceTableSchema = StructType(Seq(
- StructField("pk", IntegerType, nullable = false),
- StructField("s", StructType(Seq(
- StructField("c1", IntegerType),
- StructField("c2", StructType(Seq(
- // missing column 'a'
- StructField("m", MapType(StringType, StringType)),
- StructField("c3", BooleanType) // new column
- )))
- ))),
- StructField("dep", StringType)
- ))
- val data = Seq(
- Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"),
- Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering")
- )
- spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
- .createOrReplaceTempView("source")
+ // Insert data using DataFrame API with objects
+ val tableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", ArrayType(IntegerType)),
+ StructField("m", MapType(StringType, StringType))
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ val targetData = Seq(
+ Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
tableSchema)
+ .coalesce(1).writeTo(tableNameAsString).append()
+
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ // missing column 'a'
+ StructField("m", MapType(StringType, StringType)),
+ StructField("c3", BooleanType) // new column
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ val sourceData = Seq(
+ Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"),
+ Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(sourceData),
sourceTableSchema)
+ .createOrReplaceTempView("source")
- val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA
EVOLUTION" else ""
- val mergeStmt =
- s"""MERGE $schemaEvolutionClause
- |INTO $tableNameAsString t
- |USING source src
- |ON t.pk = src.pk
- |WHEN MATCHED THEN
- | UPDATE SET *
- |WHEN NOT MATCHED THEN
- | INSERT *
- |""".stripMargin
- if (withSchemaEvolution) {
- sql(mergeStmt)
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"),
- Row(2, Row(20, Row(null, Map("e" -> "f"), true)),
"engineering")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- sql(mergeStmt)
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA
EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause
+ |INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
+ if (withSchemaEvolution) {
+ sql(mergeStmt)
+ if (updateByFields) {
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, Row(10, Row(Seq(1, 2), Map("c" -> "d"), false)),
"sales"),
+ Row(2, Row(20, Row(null, Map("e" -> "f"), true)),
"engineering")))
+ } else {
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, Row(10, Row(null, Map("c" -> "d"), false)),
"sales"),
+ Row(2, Row(20, Row(null, Map("e" -> "f"), true)),
"engineering")))
+ }
+ } else {
+ val exception =
intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+ assert(exception.getMessage.contains(
+ "Cannot write extra fields `c3` to the struct `s`.`c2`"))
+ }
}
- assert(exception.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
- assert(exception.getMessage.contains(
- "Cannot write extra fields `c3` to the struct `s`.`c2`"))
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
+ }
+ }
+
+ test("merge into schema evolution replace column with nested struct and
update " +
+ "top level struct") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { updateByFields =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key ->
+ updateByFields.toString) {
+ withTempView("source") {
+ // Create table using Spark SQL
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ |pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: ARRAY<INT>, m: MAP<STRING,
STRING>>>,
+ |dep STRING)
+ |PARTITIONED BY (dep)
+ |""".stripMargin)
+
+ // Insert data using DataFrame API with objects
+ val tableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", ArrayType(IntegerType)),
+ StructField("m", MapType(StringType, StringType))
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ val targetData = Seq(
+ Row(1, Row(2, Row(Array(1, 2), Map("a" -> "b"))), "hr")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
tableSchema)
+ .coalesce(1).writeTo(tableNameAsString).append()
+
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ // missing column 'a'
+ StructField("m", MapType(StringType, StringType)),
+ StructField("c3", BooleanType) // new column
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ val sourceData = Seq(
+ Row(1, Row(10, Row(Map("c" -> "d"), false)), "sales"),
+ Row(2, Row(20, Row(Map("e" -> "f"), true)), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(sourceData),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA
EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause
+ |INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET s = src.s
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
+ if (withSchemaEvolution) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"),
+ Row(2, Row(20, Row(null, Map("e" -> "f"), true)),
"engineering")))
+ } else {
+ val exception =
intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+ assert(exception.getMessage.contains(
+ "Cannot write extra fields `c3` to the struct `s`.`c2`"))
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
}
- sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
}
@@ -4343,50 +4463,6 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase
}
}
- test("merge into with source missing fields in top-level struct") {
- withTempView("source") {
- // Target table has struct with 3 fields at top level
- createAndInitTable(
- s"""pk INT NOT NULL,
- |s STRUCT<c1: INT, c2: STRING, c3: BOOLEAN>,
- |dep STRING""".stripMargin,
- """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep":
"sales"}""")
-
- // Source table has struct with only 2 fields (c1, c2) - missing c3
- val sourceTableSchema = StructType(Seq(
- StructField("pk", IntegerType, nullable = false),
- StructField("s", StructType(Seq(
- StructField("c1", IntegerType),
- StructField("c2", StringType)))), // missing c3 field
- StructField("dep", StringType)))
- val data = Seq(
- Row(1, Row(10, "b"), "hr"),
- Row(2, Row(20, "c"), "engineering")
- )
- spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
- .createOrReplaceTempView("source")
-
- sql(
- s"""MERGE INTO $tableNameAsString t
- |USING source src
- |ON t.pk = src.pk
- |WHEN MATCHED THEN
- | UPDATE SET *
- |WHEN NOT MATCHED THEN
- | INSERT *
- |""".stripMargin)
-
- // Missing field c3 should be filled with NULL
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(0, Row(1, "a", true), "sales"),
- Row(1, Row(10, "b", null), "hr"),
- Row(2, Row(20, "c", null), "engineering")))
- }
- sql(s"DROP TABLE IF EXISTS $tableNameAsString")
- }
-
test("merge into with source missing fields in struct nested in array") {
withTempView("source") {
// Target table has struct with 3 fields (c1, c2, c3) in array
@@ -4540,22 +4616,459 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase
sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
- test("merge into with source missing fields in nested struct") {
- Seq(true, false).foreach { nestedTypeCoercion =>
- withSQLConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED.key
- -> nestedTypeCoercion.toString) {
+ test("merge into with source missing fields in top-level struct") {
+ withTempView("source") {
+ // Target table has struct with 3 fields at top level
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRING, c3: BOOLEAN>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep":
"sales"}""")
+
+ // Source table has struct with only 2 fields (c1, c2) - missing c3
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType)))), // missing c3 field
+ StructField("dep", StringType)))
+ val data = Seq(
+ Row(1, Row(10, "b"), "hr"),
+ Row(2, Row(20, "c"), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin)
+
+ // Missing field c3 should be filled with NULL
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, "a", true), "sales"),
+ Row(1, Row(10, "b", null), "hr"),
+ Row(2, Row(20, "c", null), "engineering")))
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+
+ test("merge with null struct") {
+ Seq(true, false).foreach { updateByFields =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key ->
+ updateByFields.toString) {
+ withTempView("source") {
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRING>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "s": { "c1": 1, "c2": "a" }, "dep": "sales" }
+ |{ "pk": 1, "s": { "c1": 2, "c2": "b" }, "dep": "hr" }"""
+ .stripMargin)
+
+ // Source table matches target table schema
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType)
+ ))),
+ StructField("dep", StringType)
+ ))
+
+ val data = Seq(
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ sql(
+ s"""MERGE INTO $tableNameAsString t USING source
+ |ON t.pk = source.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, "a"), "sales"),
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")))
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
+
+ test("merge with null struct - update field") {
+ Seq(true, false).foreach { updateByFields =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key ->
+ updateByFields.toString) {
+ withTempView("source") {
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRING>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "s": { "c1": 1, "c2": "a" }, "dep": "sales" }
+ |{ "pk": 1, "s": { "c1": 2, "c2": "b" }, "dep": "hr" }"""
+ .stripMargin)
+
+ // Source table matches target table schema
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType)
+ ))),
+ StructField("dep", StringType)
+ ))
+
+ val data = Seq(
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ sql(
+ s"""MERGE INTO $tableNameAsString t USING source
+ |ON t.pk = source.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET s = source.s
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, "a"), "sales"),
+ Row(1, null, "hr"),
+ Row(2, null, "finance")))
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
+
+ test("merge with null struct into non-nullable struct column") {
+ Seq(true, false).foreach { updateByFields =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key ->
+ updateByFields.toString) {
withTempView("source") {
- // Target table has nested struct: s.c1, s.c2.a, s.c2.b
createAndInitTable(
s"""pk INT NOT NULL,
- |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: BOOLEAN>>,
+ |s STRUCT<c1: INT, c2: STRING> NOT NULL,
|dep STRING""".stripMargin,
- """{ "pk": 1, "s": { "c1": 2, "c2": { "a": 10, "b": true } } }
- |{ "pk": 2, "s": { "c1": 2, "c2": { "a": 30, "b": false } }
}""".stripMargin)
+ """{ "pk": 0, "s": { "c1": 1, "c2": "a" }, "dep": "sales" }
+ |{ "pk": 1, "s": { "c1": 2, "c2": "b" }, "dep": "hr" }"""
+ .stripMargin)
- // Source table is missing field 'b' in nested struct s.c2
+ // Source table has null for the struct column
val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType)
+ ))),
+ StructField("dep", StringType)
+ ))
+
+ val data = Seq(
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ // Should throw an exception when trying to insert/update null into
NOT NULL column
+ val exception = intercept[Exception] {
+ sql(
+ s"""MERGE INTO $tableNameAsString t USING source
+ |ON t.pk = source.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin)
+ }
+ assert(exception.getMessage.contains(
+ "NULL value appeared in non-nullable field"))
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
+
+ test("merge with null struct with missing nested field") {
+ Seq(true, false).foreach { updateByFields =>
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(
+ SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key ->
+ updateByFields.toString,
+ SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ withTempView("source") {
+ // Target table has nested struct with fields c1 and c2
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING>>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "s": { "c1": 1, "c2": { "a": 10, "b": "x" } },
"dep": "sales" }
+ |{ "pk": 1, "s": { "c1": 2, "c2": { "a": 20, "b": "y" } },
"dep": "hr" }"""
+ .stripMargin)
+
+ // Source table has null for the nested struct
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType)
+ // missing field 'b'
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+
+ val data = Seq(
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ val mergeStmt =
+ s"""MERGE INTO $tableNameAsString t USING source
+ |ON t.pk = source.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
+
+ if (coerceNestedTypes) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, Row(10, "x")), "sales"),
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")))
+ } else {
+ // Without coercion, the merge should fail due to missing field
+ val exception =
intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
+ assert(exception.getMessage.contains(
+ "Cannot write incompatible data for the table ``: " +
+ "Cannot find data for the output column `s`.`c2`.`b`."))
+ }
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
+ }
+
+ test("merge null struct with schema evolution - source with missing and
extra nested fields") {
+ Seq(true, false).foreach { updateByFields =>
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(
+ SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key ->
+ updateByFields.toString,
+ SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ withTempView("source") {
+ // Target table has nested struct with fields c1 and c2
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING>>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "s": { "c1": 1, "c2": { "a": 10, "b": "x" } },
"dep": "sales" }
+ |{ "pk": 1, "s": { "c1": 2, "c2": { "a": 20, "b": "y" } },
"dep": "hr" }"""
+ .stripMargin)
+
+ // Source table has missing field 'b' and extra field 'c' in
nested struct
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType),
+ // missing field 'b'
+ StructField("c", StringType) // extra field 'c'
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+
+ val data = Seq(
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH
SCHEMA EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause
+ |INTO $tableNameAsString t USING source
+ |ON t.pk = source.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
+
+ if (coerceNestedTypes) {
+ if (withSchemaEvolution) {
+ // extra nested field is added
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, Row(10, "x", null)), "sales"),
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")))
+ } else {
+ // extra nested field is not added
+ val exception =
intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+ assert(exception.getMessage.contains(
+ "Cannot write incompatible data for the table ``: " +
+ "Cannot write extra fields `c` to the struct `s`.`c2`"))
+ }
+ } else {
+ // Without source struct coercion, the merge should fail
+ val exception =
intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
+ assert(exception.getMessage.contains(
+ "Cannot write incompatible data for the table ``: " +
+ "Cannot find data for the output column `s`.`c2`.`b`."))
+ }
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
+ }
+ }
+
+ test("merge null struct with non-nullable nested field - source with missing
" +
+ "and extra nested fields") {
+
+ withSQLConf(
+ SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key -> "true",
+ SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> "true") {
+ withTempView("source") {
+ // Target table has nested struct with NON-NULLABLE field b
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING NOT NULL>>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "s": { "c1": 1, "c2": { "a": 10, "b": "x" } }, "dep":
"sales" }
+ |{ "pk": 1, "s": { "c1": 2, "c2": { "a": 20, "b": "y" } }, "dep":
"hr" }"""
+ .stripMargin)
+
+ // Source table has missing field 'b' and extra field 'c' in nested
struct
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType),
+ // missing field 'b' (which is non-nullable in target)
+ StructField("c", StringType) // extra field 'c'
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+
+ val data = Seq(
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ val mergeStmt =
+ s"""MERGE WITH SCHEMA EVOLUTION
+ |INTO $tableNameAsString t USING source
+ |ON t.pk = source.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
+
+ // All cases should fail due to non-nullable constraint violation
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
+ assert(exception.getMessage.contains("Cannot write incompatible data
for the table ``: " +
+ "Cannot find data for the output column `s`.`c2`.`b`."))
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
+
+ test("merge with null struct using default value") {
+ Seq(true, false).foreach { updateByFields =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key ->
+ updateByFields.toString) {
+ withTempView("source") {
+ // Target table has nested struct with a default value
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ | pk INT NOT NULL,
+ | s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING>> DEFAULT
+ | named_struct('c1', 999, 'c2', named_struct('a', 999, 'b',
'default')),
+ | dep STRING)
+ |PARTITIONED BY (dep)
+ |""".stripMargin)
+
+ // Insert initial data using DataFrame API
+ val initialSchema = StructType(Seq(
StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ val initialData = Seq(
+ Row(0, Row(1, Row(10, "x")), "sales"),
+ Row(1, Row(2, Row(20, "y")), "hr")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(initialData),
initialSchema)
+ .writeTo(tableNameAsString).append()
+
+ // Source table has null for the nested struct
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
StructField("s", StructType(Seq(
StructField("c1", IntegerType),
StructField("c2", StructType(Seq(
@@ -4565,45 +5078,179 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase
))),
StructField("dep", StringType)
))
+
val data = Seq(
- Row(1, Row(10, Row(20)), "sales"),
- Row(2, Row(20, Row(30)), "engineering")
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")
)
spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
.createOrReplaceTempView("source")
- // Missing field b should be filled with NULL
- val mergeStmt = s"""MERGE INTO $tableNameAsString t
- |USING source src
- |ON t.pk = src.pk
- |WHEN MATCHED THEN
- | UPDATE SET *
- |WHEN NOT MATCHED THEN
- | INSERT *
- |""".stripMargin
+ sql(
+ s"""MERGE INTO $tableNameAsString t USING source
+ |ON t.pk = source.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, Row(10, "x")), "sales"),
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")))
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
- if (nestedTypeCoercion) {
- sql(mergeStmt)
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(1, Row(10, Row(20, null)), "sales"),
- Row(2, Row(20, Row(30, null)), "engineering")))
- } else {
- val exception = intercept[Exception] {
+ test("merge with source missing struct column with default value") {
+ Seq(true, false).foreach { updateByFields =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key ->
+ updateByFields.toString) {
+ withTempView("source") {
+ // Target table has nested struct with a default value
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ | pk INT NOT NULL,
+ | s STRUCT<c1: INT, c2: STRUCT<a: INT, b: STRING>> DEFAULT
+ | named_struct('c1', 999, 'c2', named_struct('a', 999, 'b',
'default')),
+ | dep STRING)
+ |PARTITIONED BY (dep)
+ |""".stripMargin)
+
+ // Insert initial data using DataFrame API
+ val initialSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ val initialData = Seq(
+ Row(0, Row(1, Row(10, "x")), "sales"),
+ Row(1, Row(2, Row(20, "y")), "hr")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(initialData),
initialSchema)
+ .writeTo(tableNameAsString).append()
+
+ // Source table is completely missing the struct column 's'
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("dep", StringType)
+ ))
+
+ val data = Seq(
+ Row(1, "engineering"),
+ Row(2, "finance")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ // When inserting without specifying the struct column, default
should be used
+ sql(
+ s"""MERGE INTO $tableNameAsString t USING source
+ |ON t.pk = source.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET dep = source.dep
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, dep) VALUES (source.pk, source.dep)
+ |""".stripMargin)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, Row(10, "x")), "sales"),
+ Row(1, Row(2, Row(20, "y")), "engineering"),
+ Row(2, Row(999, Row(999, "default")), "finance")))
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
+
+ test("merge into with source missing fields in nested struct") {
+ Seq(true, false).foreach { nestedTypeCoercion =>
+ Seq(true, false).foreach { updateByFields =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_UPDATE_BY_FIELD.key
+ -> updateByFields.toString,
+ SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key
+ -> nestedTypeCoercion.toString) {
+ withTempView("source") {
+ // Target table has nested struct: s.c1, s.c2.a, s.c2.b
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: BOOLEAN>>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 1, "s": { "c1": 2, "c2": { "a": 10, "b": true } } }
+ |{ "pk": 2, "s": { "c1": 2, "c2": { "a": 30, "b": false } }
}""".stripMargin)
+
+ // Source table is missing field 'b' in nested struct s.c2
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType)
+ // missing field 'b'
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ val data = Seq(
+ Row(1, Row(10, Row(20)), "sales"),
+ Row(2, Row(20, Row(30)), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ // Missing field b should be filled with NULL
+ val mergeStmt = s"""MERGE INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
+
+ if (nestedTypeCoercion) {
sql(mergeStmt)
+ if (updateByFields) {
+ // When updating by fields, target fields missing in the source struct are preserved
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, Row(10, Row(20, true)), "sales"),
+ Row(2, Row(20, Row(30, false)), "engineering")))
+ } else {
+ // When updating by top level column, the missing field is set
to NULL
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, Row(10, Row(20, null)), "sales"),
+ Row(2, Row(20, Row(30, null)), "engineering")))
+ }
+ } else {
+ val exception = intercept[Exception] {
+ sql(mergeStmt)
+ }
+ assert(exception.getMessage.contains(
+ """Cannot write incompatible data for the table
``""".stripMargin))
}
- assert(exception.getMessage.contains(
- """Cannot write incompatible data for the table
``""".stripMargin))
}
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
- sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
}
}
-
- test("merge with named_struct missing non-nullable field backup") {
+ test("merge with named_struct missing non-nullable field") {
withTempView("source") {
createAndInitTable(
s"""pk INT NOT NULL,
@@ -4631,7 +5278,7 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase
.createOrReplaceTempView("source")
Seq(true, false).foreach { coerceNestedTypes =>
- withSQLConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED.key
->
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
coerceNestedTypes.toString) {
// Test UPDATE with named_struct missing non-nullable field c2
val e = intercept[AnalysisException] {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala
index 8420e5e4d880..f635131dc3f7 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala
@@ -40,7 +40,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
| UPDATE SET t.txt = "error", t.i = CAST(null AS INT)""".stripMargin)
matchedActions match {
- case Seq(UpdateAction(None, assignments)) =>
+ case Seq(UpdateAction(None, assignments, _)) =>
assignments match {
case Seq(
Assignment(i: AttributeReference, AssertNotNull(iValue:
AttributeReference, _)),
@@ -80,7 +80,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
}
notMatchedBySourceActions match {
- case Seq(UpdateAction(None, assignments)) =>
+ case Seq(UpdateAction(None, assignments, _)) =>
assignments match {
case Seq(
Assignment(i: AttributeReference, AssertNotNull(_: Cast, _)),
@@ -138,7 +138,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
}
matchedActions match {
- case Seq(UpdateAction(None, assignments)) =>
+ case Seq(UpdateAction(None, assignments, _)) =>
assignments match {
case Seq(
Assignment(i: AttributeReference, iValue: AttributeReference),
@@ -184,7 +184,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
}
notMatchedBySourceActions match {
- case Seq(UpdateAction(None, assignments)) =>
+ case Seq(UpdateAction(None, assignments, _)) =>
assignments match {
case Seq(
Assignment(i: AttributeReference, iValue: AttributeReference),
@@ -217,7 +217,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
|""".stripMargin)
matchedActions match {
- case Seq(UpdateAction(None, assignments)) =>
+ case Seq(UpdateAction(None, assignments, _)) =>
assignments match {
case Seq(
Assignment(i: AttributeReference, iValue: AttributeReference),
@@ -280,7 +280,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
}
matchedActions match {
- case Seq(UpdateAction(None, assignments)) =>
+ case Seq(UpdateAction(None, assignments, _)) =>
assignments match {
case Seq(
Assignment(i: AttributeReference, iValue: AttributeReference),
@@ -342,7 +342,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
}
notMatchedBySourceActions match {
- case Seq(UpdateAction(None, assignments)) =>
+ case Seq(UpdateAction(None, assignments, _)) =>
assignments match {
case Seq(
Assignment(i: AttributeReference, iValue: AttributeReference),
@@ -463,7 +463,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
}
matchedActions match {
- case Seq(UpdateAction(None, assignments)) =>
+ case Seq(UpdateAction(None, assignments, _)) =>
assignments match {
case Seq(
Assignment(c: AttributeReference, cValue: StaticInvoke),
@@ -531,7 +531,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
}
notMatchedBySourceActions match {
- case Seq(UpdateAction(None, assignments)) =>
+ case Seq(UpdateAction(None, assignments, _)) =>
assignments match {
case Seq(
Assignment(c: AttributeReference, cValue: StaticInvoke),
@@ -691,7 +691,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i"))
Seq(true, false).foreach { coerceNestedTypes =>
-
withSQLConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED.key ->
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
coerceNestedTypes.toString) {
val mergeStmt =
s"""MERGE INTO nested_struct_table t USING nested_struct_table
src
@@ -745,7 +745,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
val actions = if (matchedActions.nonEmpty) matchedActions else
notMatchedBySourceActions
actions match {
- case Seq(UpdateAction(_, assignments)) =>
+ case Seq(UpdateAction(_, assignments, _)) =>
assignments match {
case Seq(
Assignment(
@@ -858,7 +858,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i"))
Seq(true, false).foreach { coerceNestedTypes =>
-
withSQLConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED.key ->
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
coerceNestedTypes.toString) {
val mergeStmt =
s"""MERGE INTO nested_struct_table t USING nested_struct_table
src
@@ -947,7 +947,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
| UPDATE SET t.i = DEFAULT""".stripMargin)
matchedActions match {
- case Seq(UpdateAction(None, assignments)) =>
+ case Seq(UpdateAction(None, assignments, _)) =>
assignments match {
case Seq(
Assignment(b: AttributeReference, bValue: AttributeReference),
@@ -1001,7 +1001,7 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
}
notMatchedBySourceActions match {
- case Seq(UpdateAction(None, assignments)) =>
+ case Seq(UpdateAction(None, assignments, _)) =>
assignments match {
case Seq(
Assignment(b: AttributeReference, bValue: AttributeReference),
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index dfd24a1ebe97..fb56cecb05ae 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -1703,12 +1703,12 @@ class PlanResolutionSuite extends SharedSparkSession
with AnalysisTest {
mergeCondition,
Seq(DeleteAction(Some(EqualTo(dl: AttributeReference,
StringLiteral("delete")))),
UpdateAction(Some(EqualTo(ul: AttributeReference,
StringLiteral("update"))),
- updateAssigns)),
+ updateAssigns, _)),
Seq(InsertAction(Some(EqualTo(il: AttributeReference,
StringLiteral("insert"))),
insertAssigns)),
Seq(DeleteAction(Some(EqualTo(ndl: AttributeReference,
StringLiteral("delete")))),
UpdateAction(Some(EqualTo(nul: AttributeReference,
StringLiteral("update"))),
- notMatchedBySourceUpdateAssigns)),
+ notMatchedBySourceUpdateAssigns, _)),
withSchemaEvolution) =>
checkMergeConditionResolution(target, source, mergeCondition)
checkMatchedClausesResolution(target, source, Some(dl), Some(ul),
updateAssigns)
@@ -1737,7 +1737,7 @@ class PlanResolutionSuite extends SharedSparkSession with
AnalysisTest {
mergeCondition,
Seq(DeleteAction(Some(EqualTo(dl: AttributeReference,
StringLiteral("delete")))),
UpdateAction(Some(EqualTo(ul: AttributeReference,
- StringLiteral("update"))), updateAssigns)),
+ StringLiteral("update"))), updateAssigns, _)),
Seq(InsertAction(Some(EqualTo(il: AttributeReference,
StringLiteral("insert"))),
insertAssigns)),
Seq(),
@@ -1766,7 +1766,7 @@ class PlanResolutionSuite extends SharedSparkSession with
AnalysisTest {
SubqueryAlias(AliasIdentifier("target", Seq()),
AsDataSourceV2Relation(target)),
SubqueryAlias(AliasIdentifier("source", Seq()),
AsDataSourceV2Relation(source)),
mergeCondition,
- Seq(UpdateAction(None, updateAssigns)),
+ Seq(UpdateAction(None, updateAssigns, _)),
Seq(InsertAction(None, insertAssigns)),
Seq(),
withSchemaEvolution) =>
@@ -1797,10 +1797,10 @@ class PlanResolutionSuite extends SharedSparkSession
with AnalysisTest {
SubqueryAlias(AliasIdentifier("target", Seq()),
AsDataSourceV2Relation(target)),
SubqueryAlias(AliasIdentifier("source", Seq()),
AsDataSourceV2Relation(source)),
mergeCondition,
- Seq(DeleteAction(Some(_)), UpdateAction(None, updateAssigns)),
+ Seq(DeleteAction(Some(_)), UpdateAction(None, updateAssigns, _)),
Seq(InsertAction(None, insertAssigns)),
Seq(DeleteAction(Some(EqualTo(_: AttributeReference,
StringLiteral("delete")))),
- UpdateAction(None, notMatchedBySourceUpdateAssigns)),
+ UpdateAction(None, notMatchedBySourceUpdateAssigns, _)),
withSchemaEvolution) =>
checkMergeConditionResolution(target, source, mergeCondition)
checkMatchedClausesResolution(target, source, None, None,
updateAssigns)
@@ -1832,12 +1832,12 @@ class PlanResolutionSuite extends SharedSparkSession
with AnalysisTest {
mergeCondition,
Seq(DeleteAction(Some(EqualTo(dl: AttributeReference,
StringLiteral("delete")))),
UpdateAction(Some(EqualTo(ul: AttributeReference,
StringLiteral("update"))),
- updateAssigns)),
+ updateAssigns, _)),
Seq(InsertAction(Some(EqualTo(il: AttributeReference,
StringLiteral("insert"))),
insertAssigns)),
Seq(DeleteAction(Some(EqualTo(ndl: AttributeReference,
StringLiteral("delete")))),
UpdateAction(Some(EqualTo(nul: AttributeReference,
StringLiteral("update"))),
- notMatchedBySourceUpdateAssigns)),
+ notMatchedBySourceUpdateAssigns, _)),
withSchemaEvolution) =>
checkMergeConditionResolution(target, source, mergeCondition)
checkMatchedClausesResolution(target, source, Some(dl), Some(ul),
updateAssigns)
@@ -1871,12 +1871,12 @@ class PlanResolutionSuite extends SharedSparkSession
with AnalysisTest {
mergeCondition,
Seq(DeleteAction(Some(EqualTo(dl: AttributeReference,
StringLiteral("delete")))),
UpdateAction(Some(EqualTo(ul: AttributeReference,
StringLiteral("update"))),
- updateAssigns)),
+ updateAssigns, _)),
Seq(InsertAction(Some(EqualTo(il: AttributeReference,
StringLiteral("insert"))),
insertAssigns)),
Seq(DeleteAction(Some(EqualTo(ndl: AttributeReference,
StringLiteral("delete")))),
UpdateAction(Some(EqualTo(nul: AttributeReference,
StringLiteral("update"))),
- notMatchedBySourceUpdateAssigns)),
+ notMatchedBySourceUpdateAssigns, _)),
withSchemaEvolution) =>
checkMergeConditionResolution(target, source, mergeCondition)
checkMatchedClausesResolution(target, source, Some(dl), Some(ul),
updateAssigns)
@@ -1927,7 +1927,7 @@ class PlanResolutionSuite extends SharedSparkSession with
AnalysisTest {
case UpdateAction(Some(EqualTo(_: AttributeReference,
StringLiteral("update"))),
Seq(
Assignment(_: AttributeReference, Literal(null, StringType)),
- Assignment(_: AttributeReference, _: AttributeReference))) =>
+ Assignment(_: AttributeReference, _: AttributeReference)),
_) =>
case other => fail("unexpected second matched action " + other)
}
assert(m.notMatchedActions.length == 1)
@@ -1947,7 +1947,7 @@ class PlanResolutionSuite extends SharedSparkSession with
AnalysisTest {
}
m.notMatchedBySourceActions(1) match {
case UpdateAction(Some(EqualTo(_: AttributeReference,
StringLiteral("update"))),
- Seq(Assignment(_: AttributeReference, Literal(null,
StringType)))) =>
+ Seq(Assignment(_: AttributeReference, Literal(null,
StringType))), _) =>
case other =>
fail("unexpected second not matched by source action " + other)
}
@@ -1999,7 +1999,7 @@ class PlanResolutionSuite extends SharedSparkSession with
AnalysisTest {
val second = m.matchedActions(1)
second match {
case UpdateAction(Some(EqualTo(_: AttributeReference, Literal(31,
IntegerType))),
- Seq(Assignment(_: AttributeReference, Literal(42, IntegerType)))) =>
+ Seq(Assignment(_: AttributeReference, Literal(42, IntegerType))), _)
=>
case other => fail("unexpected second matched action " + other)
}
assert(m.notMatchedActions.length == 1)
@@ -2017,7 +2017,7 @@ class PlanResolutionSuite extends SharedSparkSession with
AnalysisTest {
}
m.notMatchedBySourceActions(1) match {
case UpdateAction(Some(EqualTo(_: AttributeReference, Literal(31,
IntegerType))),
- Seq(Assignment(_: AttributeReference, Literal(42, IntegerType)))) =>
+ Seq(Assignment(_: AttributeReference, Literal(42, IntegerType))), _)
=>
case other => fail("unexpected second not matched by source action "
+ other)
}
assert(m.withSchemaEvolution === false)
@@ -2158,11 +2158,11 @@ class PlanResolutionSuite extends SharedSparkSession
with AnalysisTest {
AsDataSourceV2Relation(target),
AsDataSourceV2Relation(source),
_,
- Seq(DeleteAction(Some(_)), UpdateAction(None, firstUpdateAssigns)),
+ Seq(DeleteAction(Some(_)), UpdateAction(None, firstUpdateAssigns,
_)),
Seq(InsertAction(
Some(EqualTo(il: AttributeReference, StringLiteral("a"))),
insertAssigns)),
- Seq(DeleteAction(Some(_)), UpdateAction(None,
secondUpdateAssigns)),
+ Seq(DeleteAction(Some(_)), UpdateAction(None, secondUpdateAssigns,
_)),
withSchemaEvolution) =>
val ti = target.output.find(_.name == "i").get
val ts = target.output.find(_.name == "s").get
@@ -2282,7 +2282,7 @@ class PlanResolutionSuite extends SharedSparkSession with
AnalysisTest {
}
notMatchedBySourceActions(1) match {
case UpdateAction(Some(EqualTo(ul: AttributeReference,
StringLiteral("a"))),
- Seq(Assignment(us: AttributeReference, IntegerLiteral(1)))) =>
+ Seq(Assignment(us: AttributeReference, IntegerLiteral(1))), _)
=>
// UPDATE condition and assignment are resolved with target
table only, so column `s`
// and `i` are not ambiguous.
val ts = target.output.find(_.name == "s").get
@@ -2342,7 +2342,7 @@ class PlanResolutionSuite extends SharedSparkSession with
AnalysisTest {
AsDataSourceV2Relation(target),
AsDataSourceV2Relation(source),
EqualTo(IntegerLiteral(1), IntegerLiteral(1)),
- Seq(UpdateAction(None, updateAssigns)), // Matched actions
+ Seq(UpdateAction(None, updateAssigns, _)), // Matched actions
Seq(), // Not matched actions
Seq(), // Not matched by source actions
withSchemaEvolution) =>
@@ -2395,7 +2395,7 @@ class PlanResolutionSuite extends SharedSparkSession with
AnalysisTest {
assert(m.matchedActions.length == 1)
m.matchedActions.head match {
case UpdateAction(_, Seq(
- Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
+ Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke)),
_) =>
assert(s1.arguments.length == 2)
assert(s1.functionName == "charTypeWriteSideCheck")
assert(s2.arguments.length == 2)
@@ -2421,7 +2421,7 @@ class PlanResolutionSuite extends SharedSparkSession with
AnalysisTest {
assert(m.notMatchedBySourceActions.length == 1)
m.notMatchedBySourceActions.head match {
case UpdateAction(_, Seq(
- Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
+ Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke)),
_) =>
assert(s1.arguments.length == 2)
assert(s1.functionName == "charTypeWriteSideCheck")
assert(s2.arguments.length == 2)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]