This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new eba31a8de3f [SPARK-41806][SQL] Use AppendData.byName for SQL INSERT INTO by name for DSV2
eba31a8de3f is described below

commit eba31a8de3fb79f96255a0feb58db19842c9d16d
Author: Allison Portis <allison.por...@databricks.com>
AuthorDate: Fri Jan 6 10:42:16 2023 +0800

    [SPARK-41806][SQL] Use AppendData.byName for SQL INSERT INTO by name for DSV2

    ### What changes were proposed in this pull request?

    Use the DSv2 AppendData.byName plan for SQL INSERT INTO by name, instead of reordering the value list and converting the query to an ordinal-based AppendData.byPosition.

    ### Why are the changes needed?

    Currently, for INSERT INTO by name, we reorder the value list and convert the query to INSERT INTO by ordinal. Since the DSv2 logical nodes already carry the `isByName` flag, this conversion is unnecessary. The current approach is limiting in that:
    - Users must provide the full list of table columns, which restricts features like generated columns (see [SPARK-41290](https://issues.apache.org/jira/browse/SPARK-41290)).
    - It allows ambiguous queries such as `INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2')`, where the user provides both the static partition column 'c' and the column 'c' in the column list. We should check that the static partition column is not in the column list. See the added test for a more detailed example.

    ### Does this PR introduce _any_ user-facing change?

    For versions 3.3 and below:
    ```sql
    CREATE TABLE t(i STRING, c STRING) USING PARQUET PARTITIONED BY (c);
    INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2');
    SELECT * FROM t;
    ```
    ```
    +---+---+
    |  i|  c|
    +---+---+
    |  2|  1|
    +---+---+
    ```

    For versions 3.4 and above:
    ```sql
    CREATE TABLE t(i STRING, c STRING) USING PARQUET PARTITIONED BY (c);
    INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2');
    ```
    ```
    AnalysisException: [STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST] Static partition column c is also specified in the column list.
    ```

    ### How was this patch tested?

    Unit tests are added.

    Closes #39334 from allisonport-db/insert-into-by-name.

    Authored-by: Allison Portis <allison.por...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 core/src/main/resources/error/error-classes.json   |  5 ++
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 99 +++++++++++++++++++---
 .../spark/sql/errors/QueryCompilationErrors.scala  |  6 ++
 .../org/apache/spark/sql/SQLInsertTestSuite.scala  | 16 +++-
 .../spark/sql/connector/DataSourceV2SQLSuite.scala | 96 +++++++++++++++++++++
 .../execution/command/PlanResolutionSuite.scala    | 30 ++++++-
 6 files changed, 239 insertions(+), 13 deletions(-)
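To illustrate the new behavior end to end, here is a minimal SQL sketch mirroring the tests added in the diff below; `testcat` and the `foo` provider are placeholders for a DSv2 catalog and table provider configured as in those test suites:

```sql
-- Sketch only: assumes a DSv2 catalog registered as `testcat` and a DSv2
-- table provider (named `foo` here as a placeholder), as in the test setup.
CREATE TABLE testcat.t1 (id BIGINT, data STRING) USING foo;

-- The column list is now resolved by name (AppendData.byName), so it may be
-- given in any order instead of being reordered into an ordinal-based plan.
INSERT INTO testcat.t1 (data, id) VALUES ('a', 1);

-- A value needing a safe cast (INT to BIGINT for `id`) is still accepted.
INSERT OVERWRITE testcat.t1 (data, id) VALUES ('b', 2);
```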
diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 29cafdcc1b6..1d1952dce1b 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1145,6 +1145,11 @@
       "Star (*) is not allowed in a select list when GROUP BY an ordinal position is used."
     ]
   },
+  "STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST" : {
+    "message" : [
+      "Static partition column <staticName> is also specified in the column list."
+    ]
+  },
   "STREAM_FAILED" : {
     "message" : [
       "Query [id = <id>, runId = <runId>] terminated with exception: <message>"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1ebbfb9a39a..8fff0d41add 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1291,28 +1291,92 @@ class Analyzer(override val catalogManager: CatalogManager)
     }
   }

+  /** Handle INSERT INTO for DSv2 */
   object ResolveInsertInto extends Rule[LogicalPlan] {
+
+    /** Add a project to use the table column names for INSERT INTO BY NAME */
+    private def createProjectForByNameQuery(i: InsertIntoStatement): LogicalPlan = {
+      SchemaUtils.checkColumnNameDuplication(i.userSpecifiedCols, resolver)
+
+      if (i.userSpecifiedCols.size != i.query.output.size) {
+        throw QueryCompilationErrors.writeTableWithMismatchedColumnsError(
+          i.userSpecifiedCols.size, i.query.output.size, i.query)
+      }
+      val projectByName = i.userSpecifiedCols.zip(i.query.output)
+        .map { case (userSpecifiedCol, queryOutputCol) =>
+          val resolvedCol = i.table.resolve(Seq(userSpecifiedCol), resolver)
+            .getOrElse(
+              throw QueryCompilationErrors.unresolvedAttributeError(
+                "UNRESOLVED_COLUMN", userSpecifiedCol, i.table.output.map(_.name), i.origin))
+          (queryOutputCol.dataType, resolvedCol.dataType) match {
+            case (input: StructType, expected: StructType) =>
+              // Rename inner fields of the input column to pass the by-name INSERT analysis.
+              Alias(Cast(queryOutputCol, renameFieldsInStruct(input, expected)), resolvedCol.name)()
+            case _ =>
+              Alias(queryOutputCol, resolvedCol.name)()
+          }
+        }
+      Project(projectByName, i.query)
+    }
+
+    private def renameFieldsInStruct(input: StructType, expected: StructType): StructType = {
+      if (input.length == expected.length) {
+        val newFields = input.zip(expected).map { case (f1, f2) =>
+          (f1.dataType, f2.dataType) match {
+            case (s1: StructType, s2: StructType) =>
+              f1.copy(name = f2.name, dataType = renameFieldsInStruct(s1, s2))
+            case _ =>
+              f1.copy(name = f2.name)
+          }
+        }
+        StructType(newFields)
+      } else {
+        input
+      }
+    }
+
     override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
       AlwaysProcess.fn, ruleId) {
       case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _)
-          if i.query.resolved && i.userSpecifiedCols.isEmpty =>
+          if i.query.resolved =>
         // ifPartitionNotExists is append with validation, but validation is not supported
         if (i.ifPartitionNotExists) {
           throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name)
         }

+        // Create a project if this is an INSERT INTO BY NAME query.
+        val projectByName = if (i.userSpecifiedCols.nonEmpty) {
+          Some(createProjectForByNameQuery(i))
+        } else {
+          None
+        }
+        val isByName = projectByName.nonEmpty
+
         val partCols = partitionColumnNames(r.table)
         validatePartitionSpec(partCols, i.partitionSpec)

         val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get).toMap
-        val query = addStaticPartitionColumns(r, i.query, staticPartitions)
+        val query = addStaticPartitionColumns(r, projectByName.getOrElse(i.query), staticPartitions,
+          isByName)

         if (!i.overwrite) {
-          AppendData.byPosition(r, query)
+          if (isByName) {
+            AppendData.byName(r, query)
+          } else {
+            AppendData.byPosition(r, query)
+          }
         } else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) {
-          OverwritePartitionsDynamic.byPosition(r, query)
+          if (isByName) {
+            OverwritePartitionsDynamic.byName(r, query)
+          } else {
+            OverwritePartitionsDynamic.byPosition(r, query)
+          }
         } else {
-          OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions))
+          if (isByName) {
+            OverwriteByExpression.byName(r, query, staticDeleteExpression(r, staticPartitions))
+          } else {
+            OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions))
+          }
         }
     }

@@ -1343,7 +1407,8 @@ class Analyzer(override val catalogManager: CatalogManager)
     private def addStaticPartitionColumns(
         relation: DataSourceV2Relation,
         query: LogicalPlan,
-        staticPartitions: Map[String, String]): LogicalPlan = {
+        staticPartitions: Map[String, String],
+        isByName: Boolean): LogicalPlan = {

       if (staticPartitions.isEmpty) {
         query
@@ -1352,13 +1417,23 @@ class Analyzer(override val catalogManager: CatalogManager)
       // add any static value as a literal column
       val withStaticPartitionValues = {
         // for each static name, find the column name it will replace and check for unknowns.
-        val outputNameToStaticName = staticPartitions.keySet.map(staticName =>
+        val outputNameToStaticName = staticPartitions.keySet.map { staticName =>
+          if (isByName) {
+            // If this is INSERT INTO BY NAME, the query output's names will be the user specified
+            // column names. We need to make sure the static partition column name doesn't appear
+            // there to catch the following ambiguous query:
+            // INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2')
+            if (query.output.find(col => conf.resolver(col.name, staticName)).nonEmpty) {
+              throw QueryCompilationErrors.staticPartitionInUserSpecifiedColumnsError(staticName)
+            }
+          }
           relation.output.find(col => conf.resolver(col.name, staticName)) match {
             case Some(attr) =>
               attr.name -> staticName
             case _ =>
               throw QueryCompilationErrors.missingStaticPartitionColumn(staticName)
-          }).toMap
+          }
+        }.toMap

         val queryColumns = query.output.iterator

@@ -3646,11 +3721,15 @@ class Analyzer(override val catalogManager: CatalogManager)
     }
   }

+  /**
+   * A special rule to reorder columns for DSv1 when users specify a column list in INSERT INTO.
+   * DSv2 is handled by [[ResolveInsertInto]] separately.
+   */
   object ResolveUserSpecifiedColumns extends Rule[LogicalPlan] {
     override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
       AlwaysProcess.fn, ruleId) {
-      case i: InsertIntoStatement if i.table.resolved && i.query.resolved &&
-          i.userSpecifiedCols.nonEmpty =>
+      case i: InsertIntoStatement if !i.table.isInstanceOf[DataSourceV2Relation] &&
+          i.table.resolved && i.query.resolved && i.userSpecifiedCols.nonEmpty =>
         val resolved = resolveUserSpecifiedColumns(i)
         val projection = addColumnListOnQuery(i.table.output, resolved, i.query)
         i.copy(userSpecifiedCols = Nil, query = projection)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 621f3e1ca90..f06444847ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -163,6 +163,12 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
       messageParameters = Map("columnName" -> staticName))
   }

+  def staticPartitionInUserSpecifiedColumnsError(staticName: String): Throwable = {
+    new AnalysisException(
+      errorClass = "STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST",
+      messageParameters = Map("staticName" -> staticName))
+  }
+
   def nestedGeneratorError(trimmedNestedGenerator: Expression): Throwable = {
     new AnalysisException(errorClass = "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS",
       messageParameters = Map("expression" -> toSQLExpr(trimmedNestedGenerator)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
index f620c0b4c86..051ac0f3141 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
@@ -201,8 +201,8 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils {
     }
   }

-  test("insert with column list - mismatched target table out size after rewritten query") {
-    val v2Msg = "expected 2 columns but found"
+  test("insert with column list - missing columns") {
+    val v2Msg = "Cannot write incompatible data to table 'testcat.t1'"
     val cols = Seq("c1", "c2", "c3", "c4")

     withTable("t1") {
@@ -369,4 +369,16 @@ class DSV2SQLInsertTestSuite extends SQLInsertTestSuite with SharedSparkSession
       .set("spark.sql.catalog.testcat", classOf[InMemoryPartitionTableCatalog].getName)
       .set(SQLConf.DEFAULT_CATALOG.key, "testcat")
   }
+
+  test("static partition column name should not be used in the column list") {
+    withTable("t") {
+      sql(s"CREATE TABLE t(i STRING, c string) USING PARQUET PARTITIONED BY (c)")
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql("INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2')")
+        },
+        errorClass = "STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST",
+        parameters = Map("staticName" -> "c"))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 03b42a760ea..a4b7f762dba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -1078,6 +1078,102 @@ class DataSourceV2SQLSuiteV1Filter
     }
   }

+  test("insertInto: append by name") {
+    import testImplicits._
+    val t1 = "tbl"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
+      val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+      sql(s"INSERT INTO $t1(id, data) VALUES(1L, 'a')")
+      // Can be in a different order
+      sql(s"INSERT INTO $t1(data, id) VALUES('b', 2L)")
+      // Can be casted automatically
+      sql(s"INSERT INTO $t1(data, id) VALUES('c', 3)")
+      verifyTable(t1, df)
+      // Missing columns
+      assert(intercept[AnalysisException] {
+        sql(s"INSERT INTO $t1(data) VALUES(4)")
+      }.getMessage.contains("Cannot find data for output column 'id'"))
+      // Duplicate columns
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql(s"INSERT INTO $t1(data, data) VALUES(5)")
+        },
+        errorClass = "COLUMN_ALREADY_EXISTS",
+        parameters = Map("columnName" -> "`data`")
+      )
+    }
+  }
+
+  test("insertInto: overwrite by name") {
+    import testImplicits._
+    val t1 = "tbl"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
+      sql(s"INSERT OVERWRITE $t1(id, data) VALUES(1L, 'a')")
+      verifyTable(t1, Seq((1L, "a")).toDF("id", "data"))
+      // Can be in a different order
+      sql(s"INSERT OVERWRITE $t1(data, id) VALUES('b', 2L)")
+      verifyTable(t1, Seq((2L, "b")).toDF("id", "data"))
+      // Can be casted automatically
+      sql(s"INSERT OVERWRITE $t1(data, id) VALUES('c', 3)")
+      verifyTable(t1, Seq((3L, "c")).toDF("id", "data"))
+      // Missing columns
+      assert(intercept[AnalysisException] {
+        sql(s"INSERT OVERWRITE $t1(data) VALUES(4)")
+      }.getMessage.contains("Cannot find data for output column 'id'"))
+      // Duplicate columns
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
+        },
+        errorClass = "COLUMN_ALREADY_EXISTS",
+        parameters = Map("columnName" -> "`data`")
+      )
+    }
+  }
+
+  dynamicOverwriteTest("insertInto: dynamic overwrite by name") {
+    import testImplicits._
+    val t1 = "tbl"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string, data2 string) " +
+        s"USING $v2Format PARTITIONED BY (id)")
+      sql(s"INSERT OVERWRITE $t1(id, data, data2) VALUES(1L, 'a', 'b')")
+      verifyTable(t1, Seq((1L, "a", "b")).toDF("id", "data", "data2"))
+      // Can be in a different order
+      sql(s"INSERT OVERWRITE $t1(data, data2, id) VALUES('b', 'd', 2L)")
+      verifyTable(t1, Seq((1L, "a", "b"), (2L, "b", "d")).toDF("id", "data", "data2"))
+      // Can be casted automatically
+      sql(s"INSERT OVERWRITE $t1(data, data2, id) VALUES('c', 'e', 1)")
+      verifyTable(t1, Seq((1L, "c", "e"), (2L, "b", "d")).toDF("id", "data", "data2"))
+      // Missing columns
+      assert(intercept[AnalysisException] {
+        sql(s"INSERT OVERWRITE $t1(data, id) VALUES('a', 4)")
+      }.getMessage.contains("Cannot find data for output column 'data2'"))
+      // Duplicate columns
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
+        },
+        errorClass = "COLUMN_ALREADY_EXISTS",
+        parameters = Map("columnName" -> "`data`")
+      )
+    }
+  }
+
+  test("insertInto: static partition column name should not be used in the column list") {
+    withTable("t") {
+      sql(s"CREATE TABLE t(i STRING, c string) USING $v2Format PARTITIONED BY (c)")
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql("INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2')")
+        },
+        errorClass = "STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST",
+        parameters = Map("staticName" -> "c"))
+    }
+  }
+
   test("ShowViews: using v1 catalog, db name with multipartIdentifier ('a.b') is not allowed.") {
     checkError(
       exception = intercept[AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index d78317c81a2..00d8101df83 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat,
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EqualTo, EvalMode, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
-import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTable, CreateTableAsSelect, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, InsertIntoStatement, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable}
+import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTable, CreateTableAsSelect, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, InsertIntoStatement, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, OverwriteByExpression, OverwritePartitionsDynamic, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
 import org.apache.spark.sql.connector.FakeV2Provider
@@ -42,6 +42,7 @@ import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
+import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode}
 import org.apache.spark.sql.sources.SimpleScanSource
 import org.apache.spark.sql.types.{BooleanType, CharType, DoubleType, IntegerType, LongType, MetadataBuilder, StringType, StructField, StructType}

@@ -1237,6 +1238,33 @@ class PlanResolutionSuite extends AnalysisTest {
     }
   }

+  test("InsertIntoStatement byName") {
+    val tblName = "testcat.tab1"
+    val insertSql = s"INSERT INTO $tblName(i, s) VALUES (3, 'a')"
+    val insertParsed = parseAndResolve(insertSql)
+    val overwriteSql = s"INSERT OVERWRITE $tblName(i, s) VALUES (3, 'a')"
+    val overwriteParsed = parseAndResolve(overwriteSql)
+    insertParsed match {
+      case AppendData(_: DataSourceV2Relation, _, _, isByName, _, _) =>
+        assert(isByName)
+      case _ => fail("Expected AppendData, but got:\n" + insertParsed.treeString)
+    }
+    overwriteParsed match {
+      case OverwriteByExpression(_: DataSourceV2Relation, _, _, _, isByName, _, _) =>
+        assert(isByName)
+      case _ => fail("Expected OverwriteByExpression, but got:\n" + overwriteParsed.treeString)
+    }
+    withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
+      val dynamicOverwriteParsed = parseAndResolve(overwriteSql)
+      dynamicOverwriteParsed match {
+        case OverwritePartitionsDynamic(_: DataSourceV2Relation, _, _, isByName, _) =>
+          assert(isByName)
+        case _ =>
+          fail("Expected OverwritePartitionsDynamic, but got:\n" + dynamicOverwriteParsed.treeString)
+      }
+    }
+  }
+
   test("alter table: alter column") {
     Seq("v1Table" -> true, "v2Table" -> false, "testcat.tab" -> false).foreach {
       case (tblName, useV1Command) =>

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org