This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a1aa200bdf3 [SPARK-38838][SQL] Refactor ResolveDefaultColumns.scala to simplify helper methods
a1aa200bdf3 is described below

commit a1aa200bdf32e55ea3b1f220da882b29a7a2bf9b
Author: Daniel Tenedorio <daniel.tenedo...@databricks.com>
AuthorDate: Wed May 4 15:06:07 2022 +0800

    [SPARK-38838][SQL] Refactor ResolveDefaultColumns.scala to simplify helper methods
    
    ### What changes were proposed in this pull request?
    
    Refactor ResolveDefaultColumns.scala to simplify helper methods:
    
    1. Remove global variables entirely and pass all state through the method calls instead (see the sketch after this list).
    2. Update method names to indicate which commands they are related to.
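
    For illustration, a minimal sketch of the state-passing pattern in item 1, using hypothetical stand-in
    types (Schema, Insert, StatefulRule, StatelessRule) rather than the real InsertIntoStatement/StructType
    handled by this rule:

    ```scala
    // Hypothetical stand-ins for the real plan and schema types; illustration only.
    case class Schema(fieldNames: Seq[String])
    case class Insert(userSpecifiedCols: Seq[String])

    // Before: helpers read shared mutable fields on the rule instance.
    class StatefulRule {
      var cachedSchema: Option[Schema] = None
      def apply(i: Insert): Insert = {
        cachedSchema = Some(Schema(Seq("a", "b")))
        regenerateUserSpecifiedCols(i) // depends on the var being set first
      }
      private def regenerateUserSpecifiedCols(i: Insert): Insert =
        i.copy(userSpecifiedCols = cachedSchema.get.fieldNames)
    }

    // After: each helper receives the state it needs as an argument.
    class StatelessRule {
      def apply(i: Insert): Insert = {
        val schema = Schema(Seq("a", "b"))
        regenerateUserSpecifiedCols(i, schema) // state passed explicitly
      }
      private def regenerateUserSpecifiedCols(i: Insert, schema: Schema): Insert =
        i.copy(userSpecifiedCols = schema.fieldNames)
    }
    ```

    The second form is the shape the diff below adopts: the insert table schema is computed once per
    matched INSERT INTO command and threaded through the helper calls explicitly.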
    
    ### Why are the changes needed?
    
    This simplifies work for future DEFAULT column development.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it's just a refactoring change.
    
    ### How was this patch tested?
    
    This is a refactoring-only change, so it is covered by existing tests.
    
    Closes #36398 from dtenedor/refactor.
    
    Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../catalyst/analysis/ResolveDefaultColumns.scala  | 340 +++++++++++++--------
 1 file changed, 209 insertions(+), 131 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
index 422a1e422be..ffbe18a7dfa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, UnresolvedCatalogRelation}
@@ -48,62 +50,112 @@ import org.apache.spark.sql.types._
  * @param catalog  the catalog to use for looking up the schema of INSERT INTO table objects.
  */
 case class ResolveDefaultColumns(
-  analyzer: Analyzer,
-  catalog: SessionCatalog) extends Rule[LogicalPlan] {
-
-  // This field stores the enclosing INSERT INTO command, once we find one.
-  var enclosingInsert: Option[InsertIntoStatement] = None
-  // This field stores the schema of the target table of the above command.
-  var insertTableSchemaWithoutPartitionColumns: Option[StructType] = None
-
+    analyzer: Analyzer,
+    catalog: SessionCatalog) extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    // Initialize by clearing our reference to the enclosing INSERT INTO command.
-    enclosingInsert = None
-    insertTableSchemaWithoutPartitionColumns = None
-    // Traverse the logical query plan in preorder (top-down).
     plan.resolveOperatorsWithPruning(
       (_ => SQLConf.get.enableDefaultColumns), ruleId) {
-      case i@InsertIntoStatement(_, _, _, _, _, _)
-        if i.query.collectFirst { case u: UnresolvedInlineTable
-          if u.rows.nonEmpty && u.rows.forall(_.size == u.rows(0).size) => u
-        }.isDefined =>
-        enclosingInsert = Some(i)
-        insertTableSchemaWithoutPartitionColumns = getInsertTableSchemaWithoutPartitionColumns
-        val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i)
-        regenerated
-
-      case table: UnresolvedInlineTable
-        if enclosingInsert.isDefined =>
-        val expanded: UnresolvedInlineTable = addMissingDefaultColumnValues(table).getOrElse(table)
-        val replaced: LogicalPlan =
-          replaceExplicitDefaultColumnValues(analyzer, expanded).getOrElse(table)
-        replaced
-
+      case i: InsertIntoStatement if insertsFromInlineTable(i) =>
+        resolveDefaultColumnsForInsertFromInlineTable(i)
       case i@InsertIntoStatement(_, _, _, project: Project, _, _)
         if !project.projectList.exists(_.isInstanceOf[Star]) =>
-        enclosingInsert = Some(i)
-        insertTableSchemaWithoutPartitionColumns = getInsertTableSchemaWithoutPartitionColumns
-        val expanded: Project = addMissingDefaultColumnValues(project).getOrElse(project)
-        val replaced: Option[LogicalPlan] = replaceExplicitDefaultColumnValues(analyzer, expanded)
-        val updated: InsertIntoStatement =
-          if (replaced.isDefined) i.copy(query = replaced.get) else i
-        val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(updated)
-        regenerated
+        resolveDefaultColumnsForInsertFromProject(i)
     }
   }
 
-  // Helper method to regenerate user-specified columns of an InsertIntoStatement based on the names
-  // in the insertTableSchemaWithoutPartitionColumns field of this class.
-  private def regenerateUserSpecifiedCols(i: InsertIntoStatement): InsertIntoStatement = {
-    if (i.userSpecifiedCols.nonEmpty && insertTableSchemaWithoutPartitionColumns.isDefined) {
+  /**
+   * Checks if a logical plan is an INSERT INTO command where the inserted data comes from a VALUES
+   * list, with possible projection(s), aggregate(s), and/or alias(es) in between.
+   */
+  private def insertsFromInlineTable(i: InsertIntoStatement): Boolean = {
+    var query = i.query
+    while (query.children.size == 1) {
+      query match {
+        case _: Project | _: Aggregate | _: SubqueryAlias =>
+          query = query.children(0)
+        case _ =>
+          return false
+      }
+    }
+    query match {
+      case u: UnresolvedInlineTable
+        if u.rows.nonEmpty && u.rows.forall(_.size == u.rows(0).size) =>
+        true
+      case _ =>
+        false
+    }
+  }
+
+  /**
+   * Resolves DEFAULT column references for an INSERT INTO command satisfying the
+   * [[insertsFromInlineTable]] method.
+   */
+  private def resolveDefaultColumnsForInsertFromInlineTable(i: InsertIntoStatement): LogicalPlan = {
+    val children = mutable.Buffer.empty[LogicalPlan]
+    var node = i.query
+    while (node.children.size == 1) {
+      children.append(node)
+      node = node.children(0)
+    }
+    val table = node.asInstanceOf[UnresolvedInlineTable]
+    val insertTableSchemaWithoutPartitionColumns: StructType =
+      getInsertTableSchemaWithoutPartitionColumns(i)
+        .getOrElse(return i)
+    val regenerated: InsertIntoStatement =
+      regenerateUserSpecifiedCols(i, insertTableSchemaWithoutPartitionColumns)
+    val expanded: UnresolvedInlineTable =
+      addMissingDefaultValuesForInsertFromInlineTable(
+        table, insertTableSchemaWithoutPartitionColumns)
+    val replaced: LogicalPlan =
+      replaceExplicitDefaultValuesForInputOfInsertInto(
+        analyzer, insertTableSchemaWithoutPartitionColumns, expanded)
+        .getOrElse(return i)
+    node = replaced
+    for (child <- children.reverse) {
+      node = child.withNewChildren(Seq(node))
+    }
+    regenerated.copy(query = node)
+  }
+
+  /**
+   * Resolves DEFAULT column references for an INSERT INTO command whose query is a general
+   * projection.
+   */
+  private def resolveDefaultColumnsForInsertFromProject(i: InsertIntoStatement): LogicalPlan = {
+    val insertTableSchemaWithoutPartitionColumns: StructType =
+      getInsertTableSchemaWithoutPartitionColumns(i)
+        .getOrElse(return i)
+    val regenerated: InsertIntoStatement =
+      regenerateUserSpecifiedCols(i, insertTableSchemaWithoutPartitionColumns)
+    val project: Project = i.query.asInstanceOf[Project]
+    val expanded: Project =
+      addMissingDefaultValuesForInsertFromProject(
+        project, insertTableSchemaWithoutPartitionColumns)
+    val replaced: LogicalPlan =
+      replaceExplicitDefaultValuesForInputOfInsertInto(
+        analyzer, insertTableSchemaWithoutPartitionColumns, expanded)
+        .getOrElse(return i)
+    regenerated.copy(query = replaced)
+  }
+
+  /**
+   * Regenerates user-specified columns of an InsertIntoStatement based on the names in the
+   * insertTableSchemaWithoutPartitionColumns field of this class.
+   */
+  private def regenerateUserSpecifiedCols(
+      i: InsertIntoStatement,
+      insertTableSchemaWithoutPartitionColumns: StructType): InsertIntoStatement = {
+    if (i.userSpecifiedCols.nonEmpty) {
       i.copy(
-        userSpecifiedCols = insertTableSchemaWithoutPartitionColumns.get.fields.map(_.name))
+        userSpecifiedCols = insertTableSchemaWithoutPartitionColumns.fields.map(_.name))
     } else {
       i
     }
   }
 
-  // Helper method to check if an expression is an explicit DEFAULT column reference.
+  /**
+   * Returns true if an expression is an explicit DEFAULT column reference.
+   */
   private def isExplicitDefaultColumn(expr: Expression): Boolean = expr match {
    case u: UnresolvedAttribute if u.name.equalsIgnoreCase(CURRENT_DEFAULT_COLUMN_NAME) => true
     case _ => false
@@ -112,51 +164,49 @@ case class ResolveDefaultColumns(
   /**
    * Updates an inline table to generate missing default column values.
    */
-  private def addMissingDefaultColumnValues(
-    table: UnresolvedInlineTable): Option[UnresolvedInlineTable] = {
-    assert(enclosingInsert.isDefined)
+  private def addMissingDefaultValuesForInsertFromInlineTable(
+      table: UnresolvedInlineTable,
+      insertTableSchemaWithoutPartitionColumns: StructType): UnresolvedInlineTable = {
     val numQueryOutputs: Int = table.rows(0).size
-    val schema = insertTableSchemaWithoutPartitionColumns.getOrElse(return None)
-    val newDefaultExpressions: Seq[Expression] = getDefaultExpressions(numQueryOutputs, schema)
+    val schema = insertTableSchemaWithoutPartitionColumns
+    val newDefaultExpressions: Seq[Expression] =
+      getDefaultExpressionsForInsert(numQueryOutputs, schema)
     val newNames: Seq[String] = schema.fields.drop(numQueryOutputs).map { _.name }
-    if (newDefaultExpressions.nonEmpty) {
-      Some(table.copy(
-        names = table.names ++ newNames,
-        rows = table.rows.map { row => row ++ newDefaultExpressions }))
-    } else {
-      None
-    }
+    table.copy(
+      names = table.names ++ newNames,
+      rows = table.rows.map { row => row ++ newDefaultExpressions })
   }
 
   /**
   * Adds a new expressions to a projection to generate missing default column values.
    */
-  private def addMissingDefaultColumnValues(project: Project): Option[Project] = {
+  private def addMissingDefaultValuesForInsertFromProject(
+      project: Project,
+      insertTableSchemaWithoutPartitionColumns: StructType): Project = {
     val numQueryOutputs: Int = project.projectList.size
-    val schema = insertTableSchemaWithoutPartitionColumns.getOrElse(return None)
-    val newDefaultExpressions: Seq[Expression] = getDefaultExpressions(numQueryOutputs, schema)
+    val schema = insertTableSchemaWithoutPartitionColumns
+    val newDefaultExpressions: Seq[Expression] =
+      getDefaultExpressionsForInsert(numQueryOutputs, schema)
     val newAliases: Seq[NamedExpression] =
       newDefaultExpressions.zip(schema.fields).map {
         case (expr, field) => Alias(expr, field.name)()
       }
-    if (newDefaultExpressions.nonEmpty) {
-      Some(project.copy(projectList = project.projectList ++ newAliases))
-    } else {
-      None
-    }
+    project.copy(projectList = project.projectList ++ newAliases)
   }
 
   /**
-   * This is a helper for the addMissingDefaultColumnValues methods above.
+   * This is a helper for the addMissingDefaultValuesForInsertFromInlineTable methods above.
    */
-  private def getDefaultExpressions(numQueryOutputs: Int, schema: StructType): Seq[Expression] = {
+  private def getDefaultExpressionsForInsert(
+      numQueryOutputs: Int,
+      schema: StructType): Seq[Expression] = {
     val remainingFields: Seq[StructField] = schema.fields.drop(numQueryOutputs)
     val numDefaultExpressionsToAdd = getStructFieldsForDefaultExpressions(remainingFields).size
     Seq.fill(numDefaultExpressionsToAdd)(UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME))
   }
 
   /**
-   * This is a helper for the getDefaultExpressions methods above.
+   * This is a helper for the getDefaultExpressionsForInsert methods above.
    */
  private def getStructFieldsForDefaultExpressions(fields: Seq[StructField]): Seq[StructField] = {
     if (SQLConf.get.useNullsForMissingDefaultColumnValues) {
@@ -167,14 +217,14 @@ case class ResolveDefaultColumns(
   }
 
   /**
-   * Replaces unresolved DEFAULT column references with corresponding values in a logical plan.
+   * Replaces unresolved DEFAULT column references with corresponding values in an INSERT INTO
+   * command from a logical plan.
    */
-  private def replaceExplicitDefaultColumnValues(
-    analyzer: Analyzer,
-    input: LogicalPlan): Option[LogicalPlan] = {
-    assert(enclosingInsert.isDefined)
-    val schema = insertTableSchemaWithoutPartitionColumns.getOrElse(return None)
-    val columnNames: Seq[String] = schema.fields.map { _.name }
+  private def replaceExplicitDefaultValuesForInputOfInsertInto(
+      analyzer: Analyzer,
+      insertTableSchemaWithoutPartitionColumns: StructType,
+      input: LogicalPlan): Option[LogicalPlan] = {
+    val schema = insertTableSchemaWithoutPartitionColumns
     val defaultExpressions: Seq[Expression] = schema.fields.map {
       case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
         analyze(analyzer, f, "INSERT")
@@ -184,43 +234,40 @@ case class ResolveDefaultColumns(
     // If necessary, return a more descriptive error message if the user tries to nest the DEFAULT
     // column reference inside some other expression, such as DEFAULT + 1 (this is not allowed).
     //
-    // Note that we don't need to check if "SQLConf.get.useNullsForMissingDefaultColumnValues"
-    // after this point because this method only takes responsibility to replace *existing*
-    // DEFAULT references. In contrast, the "getDefaultExpressions" method will check that config
+    // Note that we don't need to check if "SQLConf.get.useNullsForMissingDefaultColumnValues" after
+    // this point because this method only takes responsibility to replace *existing* DEFAULT
+    // references. In contrast, the "getDefaultExpressionsForInsert" method will check that config
     // and add new NULLs if needed.
     input match {
       case table: UnresolvedInlineTable =>
-        replaceExplicitDefaultColumnValues(defaultExpressions, table)
+        replaceExplicitDefaultValuesForInlineTable(defaultExpressions, table)
       case project: Project =>
-        replaceExplicitDefaultColumnValues(defaultExpressions, columnNames, project)
+        replaceExplicitDefaultValuesForProject(defaultExpressions, project)
     }
   }
 
   /**
   * Replaces unresolved DEFAULT column references with corresponding values in an inline table.
    */
-  private def replaceExplicitDefaultColumnValues(
-    defaultExpressions: Seq[Expression],
-    table: UnresolvedInlineTable): Option[LogicalPlan] = {
+  private def replaceExplicitDefaultValuesForInlineTable(
+      defaultExpressions: Seq[Expression],
+      table: UnresolvedInlineTable): Option[LogicalPlan] = {
     var replaced = false
-    val newRows: Seq[Seq[Expression]] = {
+    val updated: Seq[Seq[Expression]] = {
       table.rows.map { row: Seq[Expression] =>
         for {
           i <- 0 until row.size
           expr = row(i)
          defaultExpr = if (i < defaultExpressions.size) defaultExpressions(i) else Literal(null)
-        } yield expr match {
-          case u: UnresolvedAttribute if isExplicitDefaultColumn(u) =>
-            replaced = true
-            defaultExpr
-          case expr@_ if expr.find { isExplicitDefaultColumn }.isDefined =>
-            throw new AnalysisException(DEFAULTS_IN_EXPRESSIONS_ERROR)
-          case _ => expr
-        }
+        } yield replaceExplicitDefaultReferenceInExpression(
+          expr, defaultExpr, DEFAULTS_IN_EXPRESSIONS_ERROR, false).map { e =>
+          replaced = true
+          e
+        }.getOrElse(expr)
       }
     }
     if (replaced) {
-      Some(table.copy(rows = newRows))
+      Some(table.copy(rows = updated))
     } else {
       None
     }
@@ -229,28 +276,20 @@ case class ResolveDefaultColumns(
   /**
   * Replaces unresolved DEFAULT column references with corresponding values in a projection.
    */
-  private def replaceExplicitDefaultColumnValues(
-    defaultExpressions: Seq[Expression],
-    colNames: Seq[String],
-    project: Project): Option[LogicalPlan] = {
+  private def replaceExplicitDefaultValuesForProject(
+      defaultExpressions: Seq[Expression],
+      project: Project): Option[LogicalPlan] = {
     var replaced = false
     val updated: Seq[NamedExpression] = {
       for {
         i <- 0 until project.projectList.size
         projectExpr = project.projectList(i)
        defaultExpr = if (i < defaultExpressions.size) defaultExpressions(i) else Literal(null)
-        colName = if (i < colNames.size) colNames(i) else ""
-      } yield projectExpr match {
-        case Alias(u: UnresolvedAttribute, _) if isExplicitDefaultColumn(u) =>
-          replaced = true
-          Alias(defaultExpr, colName)()
-        case u: UnresolvedAttribute if isExplicitDefaultColumn(u) =>
-          replaced = true
-          Alias(defaultExpr, colName)()
-        case expr@_ if expr.find { isExplicitDefaultColumn }.isDefined =>
-          throw new AnalysisException(DEFAULTS_IN_EXPRESSIONS_ERROR)
-        case _ => projectExpr
-      }
+      } yield replaceExplicitDefaultReferenceInExpression(
+        projectExpr, defaultExpr, DEFAULTS_IN_EXPRESSIONS_ERROR, true).map { e =>
+        replaced = true
+        e.asInstanceOf[NamedExpression]
+      }.getOrElse(projectExpr)
     }
     if (replaced) {
       Some(project.copy(projectList = updated))
@@ -260,35 +299,48 @@ case class ResolveDefaultColumns(
   }
 
   /**
-   * Looks up the schema for the table object of an INSERT INTO statement from the catalog.
+   * Checks if a given input expression is an unresolved "DEFAULT" attribute reference.
+   *
+   * @param input the input expression to examine.
+   * @param defaultExpr the default to return if [[input]] is an unresolved "DEFAULT" reference.
+   * @param complexDefaultError error if [[input]] is a complex expression with "DEFAULT" inside.
+   * @param addAlias if true, wraps the result with an alias of the original default column name.
+   * @return [[defaultExpr]] if [[input]] is an unresolved "DEFAULT" attribute reference.
    */
-  private def getInsertTableSchemaWithoutPartitionColumns: Option[StructType] = {
-    assert(enclosingInsert.isDefined)
-    val tableName = enclosingInsert.get.table match {
-      case r: UnresolvedRelation => TableIdentifier(r.name)
-      case r: UnresolvedCatalogRelation => r.tableMeta.identifier
-      case _ => return None
-    }
-    // Lookup the relation from the catalog by name. This either succeeds or returns some "not
-    // found" error. In the latter cases, return out of this rule without changing anything and let
-    // the analyzer return a proper error message elsewhere.
-    val lookup: LogicalPlan = try {
-      catalog.lookupRelation(tableName)
-    } catch {
-      case _: AnalysisException => return None
-    }
-    val schema: StructType = lookup match {
-      case SubqueryAlias(_, r: UnresolvedCatalogRelation) =>
-        StructType(r.tableMeta.schema.fields.dropRight(
-          enclosingInsert.get.partitionSpec.size))
-      case SubqueryAlias(_, r: View) if r.isTempView =>
-        StructType(r.schema.fields.dropRight(
-          enclosingInsert.get.partitionSpec.size))
-      case _ => return None
+  private def replaceExplicitDefaultReferenceInExpression(
+      input: Expression,
+      defaultExpr: Expression,
+      complexDefaultError: String,
+      addAlias: Boolean): Option[Expression] = {
+    input match {
+      case a@Alias(u: UnresolvedAttribute, _)
+        if isExplicitDefaultColumn(u) =>
+        Some(Alias(defaultExpr, a.name)())
+      case u: UnresolvedAttribute
+        if isExplicitDefaultColumn(u) =>
+        if (addAlias) {
+          Some(Alias(defaultExpr, u.name)())
+        } else {
+          Some(defaultExpr)
+        }
+      case expr@_
+        if expr.find(isExplicitDefaultColumn).isDefined =>
+        throw new AnalysisException(complexDefaultError)
+      case _ =>
+        None
     }
+  }
+
+  /**
+   * Looks up the schema for the table object of an INSERT INTO statement from the catalog.
+   */
+  private def getInsertTableSchemaWithoutPartitionColumns(
+      enclosingInsert: InsertIntoStatement): Option[StructType] = {
+    val target: StructType = getSchemaForTargetTable(enclosingInsert.table).getOrElse(return None)
+    val schema: StructType = StructType(target.fields.dropRight(enclosingInsert.partitionSpec.size))
     // Rearrange the columns in the result schema to match the order of the explicit column list,
     // if any.
-    val userSpecifiedCols: Seq[String] = enclosingInsert.get.userSpecifiedCols
+    val userSpecifiedCols: Seq[String] = enclosingInsert.userSpecifiedCols
     if (userSpecifiedCols.isEmpty) {
       return Some(schema)
     }
@@ -311,4 +363,30 @@ case class ResolveDefaultColumns(
     Some(StructType(userSpecifiedFields ++
       getStructFieldsForDefaultExpressions(nonUserSpecifiedFields)))
   }
+
+  /**
+   * Returns the schema for the target table of a DML command, looking into the catalog if needed.
+   */
+  private def getSchemaForTargetTable(table: LogicalPlan): Option[StructType] = {
+    // Lookup the relation from the catalog by name. This either succeeds or returns some "not
+    // found" error. In the latter cases, return out of this rule without changing anything and let
+    // the analyzer return a proper error message elsewhere.
+    val tableName: TableIdentifier = table match {
+      case r: UnresolvedRelation => TableIdentifier(r.name)
+      case r: UnresolvedCatalogRelation => r.tableMeta.identifier
+      case _ => return None
+    }
+    val lookup: LogicalPlan = try {
+      catalog.lookupRelation(tableName)
+    } catch {
+      case _: AnalysisException => return None
+    }
+    lookup match {
+      case SubqueryAlias(_, r: UnresolvedCatalogRelation) =>
+        Some(r.tableMeta.schema)
+      case SubqueryAlias(_, r: View) if r.isTempView =>
+        Some(r.schema)
+      case _ => None
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
