dtenedor commented on code in PR #36415:
URL: https://github.com/apache/spark/pull/36415#discussion_r865109560


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala:
##########
@@ -55,11 +55,21 @@ object ResolveDefaultColumns {
   // Name of attributes representing explicit references to the value stored 
in the above
   // CURRENT_DEFAULT_COLUMN_METADATA.
   val CURRENT_DEFAULT_COLUMN_NAME = "DEFAULT"
-  // Return a more descriptive error message if the user tries to nest the 
DEFAULT column reference
-  // inside some other expression, such as DEFAULT + 1 (this is not allowed).
-  val DEFAULTS_IN_EXPRESSIONS_ERROR = "Failed to execute INSERT INTO command 
because the " +
-    "VALUES list contains a DEFAULT column reference as part of another 
expression; this is " +
-    "not allowed"
+  // Return a more descriptive error message if the user tries to nest a 
DEFAULT column reference
+  // inside some other expression (such as DEFAULT + 1) in an INSERT INTO 
command's VALUES list;
+  // this is not allowed.
+  val DEFAULTS_IN_COMPLEX_EXPRESSIONS_IN_INSERT_VALUES = "Failed to execute 
INSERT INTO command " +

Review Comment:
   Sounds good, done.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala:
##########
@@ -48,62 +50,145 @@ import org.apache.spark.sql.types._
  * @param catalog  the catalog to use for looking up the schema of INSERT INTO 
table objects.
  */
 case class ResolveDefaultColumns(
-  analyzer: Analyzer,
-  catalog: SessionCatalog) extends Rule[LogicalPlan] {
-
-  // This field stores the enclosing INSERT INTO command, once we find one.
-  var enclosingInsert: Option[InsertIntoStatement] = None
-  // This field stores the schema of the target table of the above command.
-  var insertTableSchemaWithoutPartitionColumns: Option[StructType] = None
-
+    analyzer: Analyzer,
+    catalog: SessionCatalog) extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    // Initialize by clearing our reference to the enclosing INSERT INTO 
command.
-    enclosingInsert = None
-    insertTableSchemaWithoutPartitionColumns = None
-    // Traverse the logical query plan in preorder (top-down).
     plan.resolveOperatorsWithPruning(
       (_ => SQLConf.get.enableDefaultColumns), ruleId) {
-      case i@InsertIntoStatement(_, _, _, _, _, _)
-        if i.query.collectFirst { case u: UnresolvedInlineTable
-          if u.rows.nonEmpty && u.rows.forall(_.size == u.rows(0).size) => u
-        }.isDefined =>
-        enclosingInsert = Some(i)
-        insertTableSchemaWithoutPartitionColumns = 
getInsertTableSchemaWithoutPartitionColumns
-        val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i)
-        regenerated
-
-      case table: UnresolvedInlineTable
-        if enclosingInsert.isDefined =>
-        val expanded: UnresolvedInlineTable = 
addMissingDefaultColumnValues(table).getOrElse(table)
-        val replaced: LogicalPlan =
-          replaceExplicitDefaultColumnValues(analyzer, 
expanded).getOrElse(table)
-        replaced
-
+      case i: InsertIntoStatement if insertsFromInlineTable(i) =>
+        resolveDefaultColumnsForInsertFromInlineTable(i)
       case i@InsertIntoStatement(_, _, _, project: Project, _, _)
         if !project.projectList.exists(_.isInstanceOf[Star]) =>
-        enclosingInsert = Some(i)
-        insertTableSchemaWithoutPartitionColumns = 
getInsertTableSchemaWithoutPartitionColumns
-        val expanded: Project = 
addMissingDefaultColumnValues(project).getOrElse(project)
-        val replaced: Option[LogicalPlan] = 
replaceExplicitDefaultColumnValues(analyzer, expanded)
-        val updated: InsertIntoStatement =
-          if (replaced.isDefined) i.copy(query = replaced.get) else i
-        val regenerated: InsertIntoStatement = 
regenerateUserSpecifiedCols(updated)
-        regenerated
+        resolveDefaultColumnsForInsertFromProject(i)
+      case u: UpdateTable =>
+        resolveDefaultColumnsForUpdate(u)
+    }
+  }
+
+  /**
+   * Checks if a logical plan is an INSERT INTO command where the inserted 
data comes from a VALUES
+   * list, with possible projection(s), aggregate(s), and/or alias(es) in 
between.
+   */
+  private def insertsFromInlineTable(i: InsertIntoStatement): Boolean = {
+    var query = i.query
+    while (query.children.size == 1) {
+      query match {
+        case _: Project | _: Aggregate | _: SubqueryAlias =>
+          query = query.children(0)
+        case _ =>
+          return false
+      }
+    }
+    query match {
+      case u: UnresolvedInlineTable
+        if u.rows.nonEmpty && u.rows.forall(_.size == u.rows(0).size) =>
+        true
+      case _ =>
+        false
     }
   }
 
-  // Helper method to regenerate user-specified columns of an 
InsertIntoStatement based on the names
-  // in the insertTableSchemaWithoutPartitionColumns field of this class.
-  private def regenerateUserSpecifiedCols(i: InsertIntoStatement): 
InsertIntoStatement = {
-    if (i.userSpecifiedCols.nonEmpty && 
insertTableSchemaWithoutPartitionColumns.isDefined) {
+  /**
+   * Resolves DEFAULT column references for an INSERT INTO command satisfying 
the
+   * [[insertsFromInlineTable]] method.
+   */
+  private def resolveDefaultColumnsForInsertFromInlineTable(i: 
InsertIntoStatement): LogicalPlan = {
+    val children = mutable.Buffer.empty[LogicalPlan]
+    var node = i.query
+    while (node.children.size == 1) {
+      children.append(node)
+      node = node.children(0)
+    }
+    val table = node.asInstanceOf[UnresolvedInlineTable]
+    val insertTableSchemaWithoutPartitionColumns: StructType =
+      getInsertTableSchemaWithoutPartitionColumns(i)
+        .getOrElse(return i)
+    val regenerated: InsertIntoStatement =
+      regenerateUserSpecifiedCols(i, insertTableSchemaWithoutPartitionColumns)
+    val expanded: UnresolvedInlineTable =
+      addMissingDefaultValuesForInsertFromInlineTable(
+        table, insertTableSchemaWithoutPartitionColumns)
+    val replaced: LogicalPlan =
+      replaceExplicitDefaultValuesForInputOfInsertInto(
+        analyzer, insertTableSchemaWithoutPartitionColumns, expanded)
+        .getOrElse(return i)
+    node = replaced
+    for (child <- children.reverse) {
+      node = child.withNewChildren(Seq(node))
+    }
+    regenerated.copy(query = node)
+  }
+
+  /**
+   * Resolves DEFAULT column references for an INSERT INTO command whose query 
is a general
+   * projection.
+   */
+  private def resolveDefaultColumnsForInsertFromProject(i: 
InsertIntoStatement): LogicalPlan = {
+    val insertTableSchemaWithoutPartitionColumns: StructType =
+      getInsertTableSchemaWithoutPartitionColumns(i)
+        .getOrElse(return i)
+    val regenerated: InsertIntoStatement =
+      regenerateUserSpecifiedCols(i, insertTableSchemaWithoutPartitionColumns)
+    val project: Project = i.query.asInstanceOf[Project]
+    val expanded: Project =
+      addMissingDefaultValuesForInsertFromProject(
+        project, insertTableSchemaWithoutPartitionColumns)
+    val replaced: LogicalPlan =
+      replaceExplicitDefaultValuesForInputOfInsertInto(
+        analyzer, insertTableSchemaWithoutPartitionColumns, expanded)
+        .getOrElse(return i)
+    regenerated.copy(query = replaced)
+  }
+
+  /**
+   * Resolves DEFAULT column references for an UPDATE command.
+   */
+  private def resolveDefaultColumnsForUpdate(u: UpdateTable): LogicalPlan = {
+    // Return a more descriptive error message if the user tries to use a 
DEFAULT column reference
+    // inside an UPDATE command's WHERE clause; this is not allowed.
+    u.condition.map { c: Expression =>
+      if (c.find(isExplicitDefaultColumn).isDefined) {
+        throw new AnalysisException(DEFAULTS_IN_UPDATE_WHERE_CLAUSE)
+      }
+    }
+    val schema: StructType = getSchemaForTargetTable(u.table).getOrElse(return 
u)
+    val defaultExpressions: Seq[Expression] = schema.fields.map {
+      case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
+        analyze(analyzer, f, "UPDATE")
+      case _ => Literal(null)
+    }
+    // Create a map from each column name in the target table to its DEFAULT 
expression.
+    val columnNamesToExpressions: Map[String, Expression] =
+      schema.fields.zip(defaultExpressions).map {
+        case (field, expr) =>
+          field.name.map(n => if (SQLConf.get.caseSensitiveAnalysis) n.toLower 
else n) -> expr
+      }.toMap
+    // For each assignment in the UPDATE command's SET clause with a DEFAULT 
column reference on the
+    // right-hand side, look up the corresponding expression from the above 
map.
+    val newAssignments: Seq[Assignment] =
+    replaceExplicitDefaultValuesForUpdateAssignments(u.assignments, 
columnNamesToExpressions)
+      .getOrElse(return u)

Review Comment:
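   Thanks for pointing this out; `return` inside a closure (including a by-name 
argument such as `getOrElse(return u)`) is a non-local return implemented by 
throwing an exception, which can behave surprisingly. Updated this.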



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala:
##########
@@ -48,62 +50,145 @@ import org.apache.spark.sql.types._
  * @param catalog  the catalog to use for looking up the schema of INSERT INTO 
table objects.
  */
 case class ResolveDefaultColumns(
-  analyzer: Analyzer,
-  catalog: SessionCatalog) extends Rule[LogicalPlan] {
-
-  // This field stores the enclosing INSERT INTO command, once we find one.
-  var enclosingInsert: Option[InsertIntoStatement] = None
-  // This field stores the schema of the target table of the above command.
-  var insertTableSchemaWithoutPartitionColumns: Option[StructType] = None
-
+    analyzer: Analyzer,
+    catalog: SessionCatalog) extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    // Initialize by clearing our reference to the enclosing INSERT INTO 
command.
-    enclosingInsert = None
-    insertTableSchemaWithoutPartitionColumns = None
-    // Traverse the logical query plan in preorder (top-down).
     plan.resolveOperatorsWithPruning(
       (_ => SQLConf.get.enableDefaultColumns), ruleId) {
-      case i@InsertIntoStatement(_, _, _, _, _, _)
-        if i.query.collectFirst { case u: UnresolvedInlineTable
-          if u.rows.nonEmpty && u.rows.forall(_.size == u.rows(0).size) => u
-        }.isDefined =>
-        enclosingInsert = Some(i)
-        insertTableSchemaWithoutPartitionColumns = 
getInsertTableSchemaWithoutPartitionColumns
-        val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i)
-        regenerated
-
-      case table: UnresolvedInlineTable
-        if enclosingInsert.isDefined =>
-        val expanded: UnresolvedInlineTable = 
addMissingDefaultColumnValues(table).getOrElse(table)
-        val replaced: LogicalPlan =
-          replaceExplicitDefaultColumnValues(analyzer, 
expanded).getOrElse(table)
-        replaced
-
+      case i: InsertIntoStatement if insertsFromInlineTable(i) =>
+        resolveDefaultColumnsForInsertFromInlineTable(i)
       case i@InsertIntoStatement(_, _, _, project: Project, _, _)
         if !project.projectList.exists(_.isInstanceOf[Star]) =>
-        enclosingInsert = Some(i)
-        insertTableSchemaWithoutPartitionColumns = 
getInsertTableSchemaWithoutPartitionColumns
-        val expanded: Project = 
addMissingDefaultColumnValues(project).getOrElse(project)
-        val replaced: Option[LogicalPlan] = 
replaceExplicitDefaultColumnValues(analyzer, expanded)
-        val updated: InsertIntoStatement =
-          if (replaced.isDefined) i.copy(query = replaced.get) else i
-        val regenerated: InsertIntoStatement = 
regenerateUserSpecifiedCols(updated)
-        regenerated
+        resolveDefaultColumnsForInsertFromProject(i)
+      case u: UpdateTable =>
+        resolveDefaultColumnsForUpdate(u)
+    }
+  }
+
+  /**
+   * Checks if a logical plan is an INSERT INTO command where the inserted 
data comes from a VALUES
+   * list, with possible projection(s), aggregate(s), and/or alias(es) in 
between.
+   */
+  private def insertsFromInlineTable(i: InsertIntoStatement): Boolean = {
+    var query = i.query
+    while (query.children.size == 1) {
+      query match {
+        case _: Project | _: Aggregate | _: SubqueryAlias =>
+          query = query.children(0)
+        case _ =>
+          return false
+      }
+    }
+    query match {
+      case u: UnresolvedInlineTable
+        if u.rows.nonEmpty && u.rows.forall(_.size == u.rows(0).size) =>
+        true
+      case _ =>
+        false
     }
   }
 
-  // Helper method to regenerate user-specified columns of an 
InsertIntoStatement based on the names
-  // in the insertTableSchemaWithoutPartitionColumns field of this class.
-  private def regenerateUserSpecifiedCols(i: InsertIntoStatement): 
InsertIntoStatement = {
-    if (i.userSpecifiedCols.nonEmpty && 
insertTableSchemaWithoutPartitionColumns.isDefined) {
+  /**
+   * Resolves DEFAULT column references for an INSERT INTO command satisfying 
the
+   * [[insertsFromInlineTable]] method.
+   */
+  private def resolveDefaultColumnsForInsertFromInlineTable(i: 
InsertIntoStatement): LogicalPlan = {
+    val children = mutable.Buffer.empty[LogicalPlan]
+    var node = i.query
+    while (node.children.size == 1) {
+      children.append(node)
+      node = node.children(0)
+    }
+    val table = node.asInstanceOf[UnresolvedInlineTable]
+    val insertTableSchemaWithoutPartitionColumns: StructType =
+      getInsertTableSchemaWithoutPartitionColumns(i)
+        .getOrElse(return i)
+    val regenerated: InsertIntoStatement =
+      regenerateUserSpecifiedCols(i, insertTableSchemaWithoutPartitionColumns)
+    val expanded: UnresolvedInlineTable =
+      addMissingDefaultValuesForInsertFromInlineTable(
+        table, insertTableSchemaWithoutPartitionColumns)
+    val replaced: LogicalPlan =
+      replaceExplicitDefaultValuesForInputOfInsertInto(
+        analyzer, insertTableSchemaWithoutPartitionColumns, expanded)
+        .getOrElse(return i)
+    node = replaced
+    for (child <- children.reverse) {
+      node = child.withNewChildren(Seq(node))
+    }
+    regenerated.copy(query = node)
+  }
+
+  /**
+   * Resolves DEFAULT column references for an INSERT INTO command whose query 
is a general
+   * projection.
+   */
+  private def resolveDefaultColumnsForInsertFromProject(i: 
InsertIntoStatement): LogicalPlan = {
+    val insertTableSchemaWithoutPartitionColumns: StructType =
+      getInsertTableSchemaWithoutPartitionColumns(i)
+        .getOrElse(return i)
+    val regenerated: InsertIntoStatement =
+      regenerateUserSpecifiedCols(i, insertTableSchemaWithoutPartitionColumns)
+    val project: Project = i.query.asInstanceOf[Project]
+    val expanded: Project =
+      addMissingDefaultValuesForInsertFromProject(
+        project, insertTableSchemaWithoutPartitionColumns)
+    val replaced: LogicalPlan =
+      replaceExplicitDefaultValuesForInputOfInsertInto(
+        analyzer, insertTableSchemaWithoutPartitionColumns, expanded)
+        .getOrElse(return i)
+    regenerated.copy(query = replaced)
+  }
+
+  /**
+   * Resolves DEFAULT column references for an UPDATE command.
+   */
+  private def resolveDefaultColumnsForUpdate(u: UpdateTable): LogicalPlan = {
+    // Return a more descriptive error message if the user tries to use a 
DEFAULT column reference
+    // inside an UPDATE command's WHERE clause; this is not allowed.
+    u.condition.map { c: Expression =>
+      if (c.find(isExplicitDefaultColumn).isDefined) {
+        throw new AnalysisException(DEFAULTS_IN_UPDATE_WHERE_CLAUSE)
+      }
+    }
+    val schema: StructType = getSchemaForTargetTable(u.table).getOrElse(return 
u)
+    val defaultExpressions: Seq[Expression] = schema.fields.map {
+      case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
+        analyze(analyzer, f, "UPDATE")
+      case _ => Literal(null)
+    }
+    // Create a map from each column name in the target table to its DEFAULT 
expression.
+    val columnNamesToExpressions: Map[String, Expression] =

Review Comment:
   Sounds good; this way we use the same case-sensitivity lookup logic 
everywhere. Done.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to