This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push: new 4ccd530f639 [SPARK-38085][SQL] DataSource V2: Handle DELETE commands for group-based sources 4ccd530f639 is described below commit 4ccd530f639e3652b7aad7c8bcfa379847dc2b68 Author: Anton Okolnychyi <aokolnyc...@apple.com> AuthorDate: Wed Apr 13 13:47:00 2022 +0800 [SPARK-38085][SQL] DataSource V2: Handle DELETE commands for group-based sources This PR contains changes to rewrite DELETE operations for V2 data sources that can replace groups of data (e.g. files, partitions). These changes are needed to support row-level operations in Spark per SPIP SPARK-35801. No. This PR comes with tests. Closes #35395 from aokolnychyi/spark-38085. Authored-by: Anton Okolnychyi <aokolnyc...@apple.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit 5a92eccd514b7bc0513feaecb041aee2f8cd5a24) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 1 + .../catalyst/analysis/RewriteDeleteFromTable.scala | 89 +++ .../catalyst/analysis/RewriteRowLevelCommand.scala | 71 +++ .../ReplaceNullWithFalseInPredicate.scala | 3 +- .../SimplifyConditionalsInPredicate.scala | 1 + .../spark/sql/catalyst/planning/patterns.scala | 51 ++ .../sql/catalyst/plans/logical/v2Commands.scala | 92 ++- .../write/RowLevelOperationInfoImpl.scala | 25 + .../connector/write/RowLevelOperationTable.scala | 51 ++ .../spark/sql/errors/QueryCompilationErrors.scala | 4 + .../datasources/v2/DataSourceV2Implicits.scala | 10 + .../catalog/InMemoryRowLevelOperationTable.scala | 96 ++++ .../InMemoryRowLevelOperationTableCatalog.scala | 46 ++ .../sql/connector/catalog/InMemoryTable.scala | 22 +- .../spark/sql/execution/SparkOptimizer.scala | 7 +- .../datasources/v2/DataSourceV2Strategy.scala | 22 +- .../GroupBasedRowLevelOperationScanPlanning.scala | 83 +++ .../v2/OptimizeMetadataOnlyDeleteFromTable.scala | 84 +++ .../execution/datasources/v2/PushDownUtils.scala | 2 +- .../sql/execution/datasources/v2/V2Writes.scala | 24 +- .../datasources/v2/WriteToDataSourceV2Exec.scala | 15 + .../spark/sql/connector/DeleteFromTableSuite.scala | 629 +++++++++++++++++++++ .../execution/command/PlanResolutionSuite.scala | 4 +- 23 files changed, 1407 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6b44483ab1d..9fdc466b425 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -318,6 +318,7 @@ class Analyzer(override val catalogManager: CatalogManager) ResolveRandomSeed :: ResolveBinaryArithmetic :: ResolveUnion :: + RewriteDeleteFromTable :: typeCoercionRules ++ Seq(ResolveWithCTE) ++ extendedResolutionRules : _*), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala new file mode 100644 index 00000000000..85af999902e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{EqualNullSafe, Expression, Not} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, LogicalPlan, ReplaceData} +import org.apache.spark.sql.connector.catalog.{SupportsDelete, SupportsRowLevelOperations, TruncatableTable} +import org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE +import org.apache.spark.sql.connector.write.RowLevelOperationTable +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * A rule that rewrites DELETE operations using plans that operate on individual or groups of rows. + * + * If a table implements [[SupportsDelete]] and [[SupportsRowLevelOperations]], this rule will + * still rewrite the DELETE operation but the optimizer will check whether this particular DELETE + * statement can be handled by simply passing delete filters to the connector. If so, the optimizer + * will discard the rewritten plan and will allow the data source to delete using filters. + */ +object RewriteDeleteFromTable extends RewriteRowLevelCommand { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case d @ DeleteFromTable(aliasedTable, cond) if d.resolved => + EliminateSubqueryAliases(aliasedTable) match { + case DataSourceV2Relation(_: TruncatableTable, _, _, _, _) if cond == TrueLiteral => + // don't rewrite as the table supports truncation + d + + case r @ DataSourceV2Relation(t: SupportsRowLevelOperations, _, _, _, _) => + val table = buildOperationTable(t, DELETE, CaseInsensitiveStringMap.empty()) + buildReplaceDataPlan(r, table, cond) + + case DataSourceV2Relation(_: SupportsDelete, _, _, _, _) => + // don't rewrite as the table supports deletes only with filters + d + + case DataSourceV2Relation(t, _, _, _, _) => + throw QueryCompilationErrors.tableDoesNotSupportDeletesError(t) + + case _ => + d + } + } + + // build a rewrite plan for sources that support replacing groups of data (e.g. 
files, partitions) + private def buildReplaceDataPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + cond: Expression): ReplaceData = { + + // resolve all required metadata attrs that may be used for grouping data on write + // for instance, JDBC data source may cluster data by shard/host before writing + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + + // construct a read relation and include all required metadata columns + val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + + // construct a plan that contains unmatched rows in matched groups that must be carried over + // such rows do not match the condition but have to be copied over as the source can replace + // only groups of rows (e.g. if a source supports replacing files, unmatched rows in matched + // files must be carried over) + // it is safe to negate the condition here as the predicate pushdown for group-based row-level + // operations is handled in a special way + val remainingRowsFilter = Not(EqualNullSafe(cond, TrueLiteral)) + val remainingRowsPlan = Filter(remainingRowsFilter, readRelation) + + // build a plan to replace read groups in the table + val writeRelation = relation.copy(table = operationTable) + ReplaceData(writeRelation, cond, remainingRowsPlan, relation) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala new file mode 100644 index 00000000000..bf8c3e27f4d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
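
To make the carried-over-rows predicate above concrete, here is a minimal sketch. It is not part of the patch; it only uses Catalyst expression classes that ship with Spark, with a hypothetical DELETE condition `id > 5` evaluated for a row whose `id` is NULL.

    import org.apache.spark.sql.catalyst.expressions.{EqualNullSafe, GreaterThan, Literal, Not}
    import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
    import org.apache.spark.sql.types.IntegerType

    // DELETE FROM t WHERE id > 5, evaluated for a row where id IS NULL
    val cond = GreaterThan(Literal(null, IntegerType), Literal(5))

    // a plain negation evaluates to NULL, so a Filter would silently drop the row
    Not(cond).eval()                              // => null

    // the null-safe form evaluates to TRUE, so the row is copied into the new group
    Not(EqualNullSafe(cond, TrueLiteral)).eval()  // => true

Rows for which the DELETE condition is NULL must not be deleted, so the null-safe negation keeps them when the whole group is rewritten.
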
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperationInfoImpl, RowLevelOperationTable} +import org.apache.spark.sql.connector.write.RowLevelOperation.Command +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +trait RewriteRowLevelCommand extends Rule[LogicalPlan] { + + protected def buildOperationTable( + table: SupportsRowLevelOperations, + command: Command, + options: CaseInsensitiveStringMap): RowLevelOperationTable = { + val info = RowLevelOperationInfoImpl(command, options) + val operation = table.newRowLevelOperationBuilder(info).build() + RowLevelOperationTable(table, operation) + } + + protected def buildRelationWithAttrs( + relation: DataSourceV2Relation, + table: RowLevelOperationTable, + metadataAttrs: Seq[AttributeReference]): DataSourceV2Relation = { + + val attrs = dedupAttrs(relation.output ++ metadataAttrs) + relation.copy(table = table, output = attrs) + } + + protected def dedupAttrs(attrs: Seq[AttributeReference]): Seq[AttributeReference] = { + val exprIds = mutable.Set.empty[ExprId] + attrs.flatMap { attr => + if (exprIds.contains(attr.exprId)) { + None + } else { + exprIds += attr.exprId + Some(attr) + } + } + } + + protected def resolveRequiredMetadataAttrs( + relation: DataSourceV2Relation, + operation: RowLevelOperation): Seq[AttributeReference] = { + + V2ExpressionUtils.resolveRefs[AttributeReference]( + operation.requiredMetadataAttributes, + relation) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 9ec498aa14e..d060a8be5da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, EqualNullSafe, Expression, If, In, InSet, LambdaFunction, Literal, MapFilter, Not, Or} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} -import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, InsertStarAction, Join, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateStarAction, UpdateTable} +import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, InsertStarAction, Join, LogicalPlan, MergeAction, MergeIntoTable, ReplaceData, UpdateAction, UpdateStarAction, UpdateTable} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{INSET, NULL_LITERAL, TRUE_OR_FALSE_LITERAL} import org.apache.spark.sql.types.BooleanType @@ -54,6 +54,7 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { _.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL, INSET), ruleId) { case f @ Filter(cond, _) => 
f.copy(condition = replaceNullWithFalse(cond)) case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond))) + case rd @ ReplaceData(_, cond, _, _, _) => rd.copy(condition = replaceNullWithFalse(cond)) case d @ DeleteFromTable(_, cond) => d.copy(condition = replaceNullWithFalse(cond)) case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond))) case m @ MergeIntoTable(_, _, mergeCond, matchedActions, notMatchedActions) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalsInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalsInPredicate.scala index e1972b997c2..34773b24cac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalsInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalsInPredicate.scala @@ -48,6 +48,7 @@ object SimplifyConditionalsInPredicate extends Rule[LogicalPlan] { _.containsAnyPattern(CASE_WHEN, IF), ruleId) { case f @ Filter(cond, _) => f.copy(condition = simplifyConditional(cond)) case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(simplifyConditional(cond))) + case rd @ ReplaceData(_, cond, _, _, _) => rd.copy(condition = simplifyConditional(cond)) case d @ DeleteFromTable(_, cond) => d.copy(condition = simplifyConditional(cond)) case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(simplifyConditional(cond))) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 8c41ab2797b..382909d6d6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -18,12 +18,15 @@ package org.apache.spark.sql.catalyst.planning import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.internal.SQLConf trait OperationHelper extends AliasHelper with PredicateHelper { @@ -388,3 +391,51 @@ object ExtractSingleColumnNullAwareAntiJoin extends JoinSelectionHelper with Pre case _ => None } } + +/** + * An extractor for row-level commands such as DELETE, UPDATE, MERGE that were rewritten using plans + * that operate on groups of rows. 
+ * + * This class extracts the following entities: + * - the group-based rewrite plan; + * - the condition that defines matching groups; + * - the read relation that can be either [[DataSourceV2Relation]] or [[DataSourceV2ScanRelation]] + * depending on whether the planning has already happened; + */ +object GroupBasedRowLevelOperation { + type ReturnType = (ReplaceData, Expression, LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case rd @ ReplaceData(DataSourceV2Relation(table, _, _, _, _), cond, query, _, _) => + val readRelation = findReadRelation(table, query) + readRelation.map((rd, cond, _)) + + case _ => + None + } + + private def findReadRelation( + table: Table, + plan: LogicalPlan): Option[LogicalPlan] = { + + val readRelations = plan.collect { + case r: DataSourceV2Relation if r.table eq table => r + case r: DataSourceV2ScanRelation if r.relation.table eq table => r + } + + // in some cases, the optimizer replaces the v2 read relation with a local relation + // for example, there is no reason to query the table if the condition is always false + // that's why it is valid not to find the corresponding v2 read relation + + readRelations match { + case relations if relations.isEmpty => + None + + case Seq(relation) => + Some(relation) + + case relations => + throw new AnalysisException(s"Expected only one row-level read relation: $relations") + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index b2ca34668a6..b1b8843aa33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, FieldName, NamedRelation, PartitionSpec, ResolvedDBObjectName, UnresolvedException} +import org.apache.spark.sql.{sources, AnalysisException} +import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedDBObjectName, UnresolvedException} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.FunctionResource -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, Unevaluable} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, Unevaluable} import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.connector.write.Write +import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperationTable, Write} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{BooleanType, DataType, MetadataBuilder, StringType, StructType} /** @@ -176,6 +178,80 @@ object OverwritePartitionsDynamic { } } +trait RowLevelWrite extends V2WriteCommand with SupportsSubquery { + def operation: RowLevelOperation + def condition: Expression + def originalTable: NamedRelation +} + +/** + * 
Replace groups of data in an existing table during a row-level operation. + * + * This node is constructed in rules that rewrite DELETE, UPDATE, MERGE operations for data sources + * that can replace groups of data (e.g. files, partitions). + * + * @param table a plan that references a row-level operation table + * @param condition a condition that defines matching groups + * @param query a query with records that should replace the records that were read + * @param originalTable a plan for the original table for which the row-level command was triggered + * @param write a logical write, if already constructed + */ +case class ReplaceData( + table: NamedRelation, + condition: Expression, + query: LogicalPlan, + originalTable: NamedRelation, + write: Option[Write] = None) extends RowLevelWrite { + + override val isByName: Boolean = false + override val stringArgs: Iterator[Any] = Iterator(table, query, write) + + override lazy val references: AttributeSet = query.outputSet + + lazy val operation: RowLevelOperation = { + EliminateSubqueryAliases(table) match { + case DataSourceV2Relation(RowLevelOperationTable(_, operation), _, _, _, _) => + operation + case _ => + throw new AnalysisException(s"Cannot retrieve row-level operation from $table") + } + } + + // the incoming query may include metadata columns + lazy val dataInput: Seq[Attribute] = { + query.output.filter { + case MetadataAttribute(_) => false + case _ => true + } + } + + override def outputResolved: Boolean = { + assert(table.resolved && query.resolved, + "`outputResolved` can only be called when `table` and `query` are both resolved.") + + // take into account only incoming data columns and ignore metadata columns in the query + // they will be discarded after the logical write is built in the optimizer + // metadata columns may be needed to request a correct distribution or ordering + // but are not passed back to the data source during writes + + table.skipSchemaResolution || (dataInput.size == table.output.size && + dataInput.zip(table.output).forall { case (inAttr, outAttr) => + val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) + // names and types must match, nullability must be compatible + inAttr.name == outAttr.name && + DataType.equalsIgnoreCompatibleNullability(inAttr.dataType, outType) && + (outAttr.nullable || !inAttr.nullable) + }) + } + + override def withNewQuery(newQuery: LogicalPlan): ReplaceData = copy(query = newQuery) + + override def withNewTable(newTable: NamedRelation): ReplaceData = copy(table = newTable) + + override protected def withNewChildInternal(newChild: LogicalPlan): ReplaceData = { + copy(query = newChild) + } +} /** A trait used for logical plan nodes that create or replace V2 table definitions. */ trait V2CreateTablePlan extends LogicalPlan { @@ -457,6 +533,16 @@ case class DeleteFromTable( copy(table = newChild) } +/** + * The logical plan of the DELETE FROM command that can be executed using data source filters. + * + * As opposed to [[DeleteFromTable]], this node represents a DELETE operation where the condition + * was converted into filters and the data source reported that it can handle all of them. + */ +case class DeleteFromTableWithFilters( + table: LogicalPlan, + condition: Seq[sources.Filter]) extends LeafCommand + /** * The logical plan of the UPDATE TABLE command. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala new file mode 100644 index 00000000000..9d499cdef36 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.write + +import org.apache.spark.sql.connector.write.RowLevelOperation.Command +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +private[sql] case class RowLevelOperationInfoImpl( + command: Command, + options: CaseInsensitiveStringMap) extends RowLevelOperationInfo diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationTable.scala new file mode 100644 index 00000000000..d1f7ba000c6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationTable.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.write + +import java.util + +import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsRowLevelOperations, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * An internal v2 table implementation that wraps the original table and a logical row-level + * operation for DELETE, UPDATE, MERGE commands that require rewriting data. + * + * The purpose of this table is to make the existing scan and write planning rules work + * with commands that require coordination between the scan and the write (so that the write + * knows what to replace). 
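
DeleteFromTableWithFilters above carries data source filters rather than a Catalyst condition. The sketch below is not from the patch; it is a simplified helper showing the SupportsDelete contract that such a plan ultimately targets, with a hypothetical filter for DELETE FROM t WHERE id <= 10.

    import org.apache.spark.sql.connector.catalog.SupportsDelete
    import org.apache.spark.sql.sources.{Filter, LessThanOrEqual}

    // the filters for a hypothetical DELETE FROM t WHERE id <= 10
    val filters: Array[Filter] = Array(LessThanOrEqual("id", 10))

    // simplified decision: skip the row-level rewrite only if the source accepts
    // every filter, otherwise keep the group-based ReplaceData plan
    def deleteWithFiltersIfPossible(table: SupportsDelete): Boolean = {
      if (table.canDeleteWhere(filters)) {
        table.deleteWhere(filters) // metadata-only delete, no rows are read or rewritten
        true
      } else {
        false
      }
    }
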
+ */ +private[sql] case class RowLevelOperationTable( + table: Table with SupportsRowLevelOperations, + operation: RowLevelOperation) extends Table with SupportsRead with SupportsWrite { + + override def name: String = table.name + override def schema: StructType = table.schema + override def capabilities: util.Set[TableCapability] = table.capabilities + override def toString: String = table.toString + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + operation.newScanBuilder(options) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + operation.newWriteBuilder(info) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 57ed7da7b20..0532a953ef4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -926,6 +926,10 @@ object QueryCompilationErrors { tableDoesNotSupportError("atomic partition management", table) } + def tableIsNotRowLevelOperationTableError(table: Table): Throwable = { + throw new AnalysisException(s"Table ${table.name} is not a row-level operation table") + } + def cannotRenameTableWithAlterViewError(): Throwable = { new AnalysisException( "Cannot rename a table with ALTER VIEW. Please use ALTER TABLE instead.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala index efd3ffebf5c..16d5a9cc70d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{PartitionSpec, ResolvedPartitionS import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsAtomicPartitionManagement, SupportsDelete, SupportsPartitionManagement, SupportsRead, SupportsWrite, Table, TableCapability, TruncatableTable} +import org.apache.spark.sql.connector.write.RowLevelOperationTable import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -82,6 +83,15 @@ object DataSourceV2Implicits { } } + def asRowLevelOperationTable: RowLevelOperationTable = { + table match { + case rowLevelOperationTable: RowLevelOperationTable => + rowLevelOperationTable + case _ => + throw QueryCompilationErrors.tableIsNotRowLevelOperationTableError(table) + } + } + def supports(capability: TableCapability): Boolean = table.capabilities.contains(capability) def supportsAny(capabilities: TableCapability*): Boolean = capabilities.exists(supports) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala new file mode 100644 index 00000000000..cb061602ec1 --- /dev/null +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import java.util + +import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} +import org.apache.spark.sql.connector.expressions.{FieldReference, LogicalExpressions, NamedReference, SortDirection, SortOrder, Transform} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} +import org.apache.spark.sql.connector.write.{BatchWrite, LogicalWriteInfo, RequiresDistributionAndOrdering, RowLevelOperation, RowLevelOperationBuilder, RowLevelOperationInfo, Write, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.connector.write.RowLevelOperation.Command +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class InMemoryRowLevelOperationTable( + name: String, + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]) + extends InMemoryTable(name, schema, partitioning, properties) with SupportsRowLevelOperations { + + override def newRowLevelOperationBuilder( + info: RowLevelOperationInfo): RowLevelOperationBuilder = { + () => PartitionBasedOperation(info.command) + } + + case class PartitionBasedOperation(command: Command) extends RowLevelOperation { + private final val PARTITION_COLUMN_REF = FieldReference(PartitionKeyColumn.name) + + var configuredScan: InMemoryBatchScan = _ + + override def requiredMetadataAttributes(): Array[NamedReference] = { + Array(PARTITION_COLUMN_REF) + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new InMemoryScanBuilder(schema) { + override def build: Scan = { + val scan = super.build() + configuredScan = scan.asInstanceOf[InMemoryBatchScan] + scan + } + } + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = new WriteBuilder { + + override def build(): Write = new Write with RequiresDistributionAndOrdering { + override def requiredDistribution(): Distribution = { + Distributions.clustered(Array(PARTITION_COLUMN_REF)) + } + + override def requiredOrdering(): Array[SortOrder] = { + Array[SortOrder]( + LogicalExpressions.sort( + PARTITION_COLUMN_REF, + SortDirection.ASCENDING, + SortDirection.ASCENDING.defaultNullOrdering()) + ) + } + + override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan) + + override def description(): String = "InMemoryWrite" + } + } + + override def description(): String = "InMemoryPartitionReplaceOperation" + } + + private case class PartitionBasedReplaceData(scan: InMemoryBatchScan) extends TestBatchWrite { + + override def commit(messages: 
Array[WriterCommitMessage]): Unit = dataMap.synchronized { + val newData = messages.map(_.asInstanceOf[BufferedRows]) + val readRows = scan.data.flatMap(_.asInstanceOf[BufferedRows].rows) + val readPartitions = readRows.map(r => getKey(r, schema)) + dataMap --= readPartitions + withData(newData, schema) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala new file mode 100644 index 00000000000..2d9a9f04785 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import java.util + +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.types.StructType + +class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog { + import CatalogV2Implicits._ + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + if (tables.containsKey(ident)) { + throw new TableAlreadyExistsException(ident) + } + + InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) + + val tableName = s"$name.${ident.quoted}" + val table = new InMemoryRowLevelOperationTable(tableName, schema, partitions, properties) + tables.put(ident, table) + namespaces.putIfAbsent(ident.namespace.toList, Map()) + table + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index a762b0f8783..beed9111a30 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -56,7 +56,7 @@ class InMemoryTable( extends Table with SupportsRead with SupportsWrite with SupportsDelete with SupportsMetadataColumns { - private object PartitionKeyColumn extends MetadataColumn { + protected object PartitionKeyColumn extends MetadataColumn { override def name: String = "_partition" override def dataType: DataType = StringType override def comment: String = "Partition key used to store the row" @@ -104,7 +104,11 @@ class InMemoryTable( private val UTC = ZoneId.of("UTC") private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate - private def getKey(row: InternalRow): Seq[Any] = { + protected def getKey(row: InternalRow): 
Seq[Any] = { + getKey(row, schema) + } + + protected def getKey(row: InternalRow, rowSchema: StructType): Seq[Any] = { @scala.annotation.tailrec def extractor( fieldNames: Array[String], @@ -124,7 +128,7 @@ class InMemoryTable( } } - val cleanedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) + val cleanedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(rowSchema) partitioning.map { case IdentityTransform(ref) => extractor(ref.fieldNames, cleanedSchema, row)._1 @@ -219,9 +223,15 @@ class InMemoryTable( dataMap(key).clear() } - def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized { + def withData(data: Array[BufferedRows]): InMemoryTable = { + withData(data, schema) + } + + def withData( + data: Array[BufferedRows], + writeSchema: StructType): InMemoryTable = dataMap.synchronized { data.foreach(_.rows.foreach { row => - val key = getKey(row) + val key = getKey(row, writeSchema) dataMap += dataMap.get(key) .map(key -> _.withRow(row)) .getOrElse(key -> new BufferedRows(key).withRow(row)) @@ -372,7 +382,7 @@ class InMemoryTable( } } - private abstract class TestBatchWrite extends BatchWrite { + protected abstract class TestBatchWrite extends BatchWrite { override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { BufferedRowsWriterFactory } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index bfe4bd29241..8c134363af1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.datasources.SchemaPruning -import org.apache.spark.sql.execution.datasources.v2.{V2ScanPartitioning, V2ScanRelationPushDown, V2Writes} +import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioning, V2ScanRelationPushDown, V2Writes} import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning} import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs} @@ -38,11 +38,15 @@ class SparkOptimizer( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = // TODO: move SchemaPruning into catalyst Seq(SchemaPruning) :+ + GroupBasedRowLevelOperationScanPlanning :+ V2ScanRelationPushDown :+ V2ScanPartitioning :+ V2Writes :+ PruneFileSourcePartitions + override def preCBORules: Seq[Rule[LogicalPlan]] = + OptimizeMetadataOnlyDeleteFromTable :: Nil + override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("PartitionPruning", Once, @@ -78,6 +82,7 @@ class SparkOptimizer( ExtractPythonUDFFromJoinCondition.ruleName :+ ExtractPythonUDFFromAggregate.ruleName :+ ExtractGroupingPythonUDFFromAggregate.ruleName :+ ExtractPythonUDFs.ruleName :+ + GroupBasedRowLevelOperationScanPlanning.ruleName :+ V2ScanRelationPushDown.ruleName :+ V2ScanPartitioning.ruleName :+ V2Writes.ruleName diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 45540fb4a11..95418027187 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -25,10 +25,11 @@ import org.apache.spark.sql.catalyst.analysis.{ResolvedDBObjectName, ResolvedNam import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, Not, Or, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{toPrettySQL, V2ExpressionBuilder} -import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog} +import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDelete, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable} import org.apache.spark.sql.connector.catalog.index.SupportsIndex import org.apache.spark.sql.connector.expressions.{FieldReference} import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate} @@ -254,6 +255,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _, _, Some(write)) => OverwritePartitionsDynamicExec(planLater(query), refreshCache(r), write) :: Nil + case DeleteFromTableWithFilters(r: DataSourceV2Relation, filters) => + DeleteFromTableExec(r.table.asDeletable, filters.toArray, refreshCache(r)) :: Nil + case DeleteFromTable(relation, condition) => relation match { case DataSourceV2ScanRelation(r, _, output, _) => @@ -269,15 +273,25 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat throw QueryCompilationErrors.cannotTranslateExpressionToSourceFilterError(f)) }).toArray - if (!table.asDeletable.canDeleteWhere(filters)) { - throw QueryCompilationErrors.cannotDeleteTableWhereFiltersError(table, filters) + table match { + case t: SupportsDelete if t.canDeleteWhere(filters) => + DeleteFromTableExec(t, filters, refreshCache(r)) :: Nil + case t: SupportsDelete => + throw QueryCompilationErrors.cannotDeleteTableWhereFiltersError(t, filters) + case t: TruncatableTable if condition == TrueLiteral => + TruncateTableExec(t, refreshCache(r)) :: Nil + case _ => + throw QueryCompilationErrors.tableDoesNotSupportDeletesError(table) } - DeleteFromTableExec(table.asDeletable, filters, refreshCache(r)) :: Nil case _ => throw QueryCompilationErrors.deleteOnlySupportedWithV2TablesError() } + case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, Some(write)) => + // use the original relation to refresh the cache + ReplaceDataExec(planLater(query), refreshCache(r), write) :: Nil + case WriteToContinuousDataSource(writer, query, customMetrics) => WriteToContinuousDataSourceExec(writer, planLater(query), customMetrics) :: Nil diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala new file mode 100644 index 00000000000..48dee3f652c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.planning.GroupBasedRowLevelOperation +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReplaceData} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.expressions.filter.{Predicate => V2Filter} +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.Filter + +/** + * A rule that builds scans for group-based row-level operations. + * + * Note this rule must be run before [[V2ScanRelationPushDown]] as scans for group-based + * row-level operations must be planned in a special way. + */ +object GroupBasedRowLevelOperationScanPlanning extends Rule[LogicalPlan] with PredicateHelper { + + import DataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + // push down the filter from the command condition instead of the filter in the rewrite plan, + // which is negated for data sources that only support replacing groups of data (e.g. 
files) + case GroupBasedRowLevelOperation(rd: ReplaceData, cond, relation: DataSourceV2Relation) => + val table = relation.table.asRowLevelOperationTable + val scanBuilder = table.newScanBuilder(relation.options) + + val (pushedFilters, remainingFilters) = pushFilters(cond, relation.output, scanBuilder) + val pushedFiltersStr = if (pushedFilters.isLeft) { + pushedFilters.left.get.mkString(", ") + } else { + pushedFilters.right.get.mkString(", ") + } + + val (scan, output) = PushDownUtils.pruneColumns(scanBuilder, relation, relation.output, Nil) + + logInfo( + s""" + |Pushing operators to ${relation.name} + |Pushed filters: $pushedFiltersStr + |Filters that were not pushed: ${remainingFilters.mkString(", ")} + |Output: ${output.mkString(", ")} + """.stripMargin) + + // replace DataSourceV2Relation with DataSourceV2ScanRelation for the row operation table + rd transform { + case r: DataSourceV2Relation if r eq relation => + DataSourceV2ScanRelation(r, scan, PushDownUtils.toOutputAttrs(scan.readSchema(), r)) + } + } + + private def pushFilters( + cond: Expression, + tableAttrs: Seq[AttributeReference], + scanBuilder: ScanBuilder): (Either[Seq[Filter], Seq[V2Filter]], Seq[Expression]) = { + + val tableAttrSet = AttributeSet(tableAttrs) + val filters = splitConjunctivePredicates(cond).filter(_.references.subsetOf(tableAttrSet)) + val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, tableAttrs) + val (_, normalizedFiltersWithoutSubquery) = + normalizedFilters.partition(SubqueryExpression.hasSubquery) + + PushDownUtils.pushFilters(scanBuilder, normalizedFiltersWithoutSubquery) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala new file mode 100644 index 00000000000..bc45dbe9fef --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
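
The scan planning rule above pushes the original command condition into the source's ScanBuilder, not the negated carry-over filter from the rewrite plan. Below is a small illustration of the filter translation step; it is not part of the patch, uses a hypothetical `id` column, and the same translation also backs the metadata-only DELETE rule defined in the file that follows.

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, LessThanOrEqual, Literal}
    import org.apache.spark.sql.execution.datasources.DataSourceStrategy
    import org.apache.spark.sql.types.IntegerType

    // the DELETE condition `id <= 1` as a resolved Catalyst predicate
    val id = AttributeReference("id", IntegerType)()
    val cond = LessThanOrEqual(id, Literal(1))

    // what is offered to the source when planning the row-level scan
    DataSourceStrategy.translateFilter(cond, supportNestedPredicatePushdown = true)
    // => Some(org.apache.spark.sql.sources.LessThanOrEqual("id", 1))
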
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.{Expression, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, DeleteFromTableWithFilters, LogicalPlan, ReplaceData, RowLevelWrite} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.{SupportsDelete, TruncatableTable} +import org.apache.spark.sql.connector.write.RowLevelOperation +import org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources + +/** + * A rule that replaces a rewritten DELETE operation with a delete using filters if the data source + * can handle this DELETE command without executing the plan that operates on individual or groups + * of rows. + * + * Note this rule must be run after expression optimization but before scan planning. + */ +object OptimizeMetadataOnlyDeleteFromTable extends Rule[LogicalPlan] with PredicateHelper { + + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case RewrittenRowLevelCommand(rowLevelPlan, DELETE, cond, relation: DataSourceV2Relation) => + relation.table match { + case table: SupportsDelete if !SubqueryExpression.hasSubquery(cond) => + val predicates = splitConjunctivePredicates(cond) + val normalizedPredicates = DataSourceStrategy.normalizeExprs(predicates, relation.output) + val filters = toDataSourceFilters(normalizedPredicates) + val allPredicatesTranslated = normalizedPredicates.size == filters.length + if (allPredicatesTranslated && table.canDeleteWhere(filters)) { + logDebug(s"Switching to delete with filters: ${filters.mkString("[", ", ", "]")}") + DeleteFromTableWithFilters(relation, filters) + } else { + rowLevelPlan + } + + case _: TruncatableTable if cond == TrueLiteral => + DeleteFromTable(relation, cond) + + case _ => + rowLevelPlan + } + } + + private def toDataSourceFilters(predicates: Seq[Expression]): Array[sources.Filter] = { + predicates.flatMap { p => + val filter = DataSourceStrategy.translateFilter(p, supportNestedPredicatePushdown = true) + if (filter.isEmpty) { + logDebug(s"Cannot translate expression to data source filter: $p") + } + filter + }.toArray + } + + private object RewrittenRowLevelCommand { + type ReturnType = (RowLevelWrite, RowLevelOperation.Command, Expression, LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case rd @ ReplaceData(_, cond, _, originalTable, _) => + val command = rd.operation.command + Some(rd, command, cond, originalTable) + + case _ => + None + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 862189ed3af..8ac91e02579 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -187,7 +187,7 @@ object PushDownUtils extends PredicateHelper { } } - private def toOutputAttrs( + def toOutputAttrs( schema: StructType, relation: DataSourceV2Relation): Seq[AttributeReference] = { val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala index 38f741532d7..2fd1d52fd98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.UUID import org.apache.spark.sql.catalyst.expressions.PredicateHelper -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project, ReplaceData} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, WriteT import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend import org.apache.spark.sql.sources.{AlwaysTrue, Filter} import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType /** * A rule that constructs logical writes. @@ -41,7 +42,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case a @ AppendData(r: DataSourceV2Relation, query, options, _, None) => - val writeBuilder = newWriteBuilder(r.table, query, options) + val writeBuilder = newWriteBuilder(r.table, options, query.schema) val write = writeBuilder.build() val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf) a.copy(write = Some(write), query = newQuery) @@ -57,7 +58,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { }.toArray val table = r.table - val writeBuilder = newWriteBuilder(table, query, options) + val writeBuilder = newWriteBuilder(table, options, query.schema) val write = writeBuilder match { case builder: SupportsTruncate if isTruncate(filters) => builder.truncate().build() @@ -72,7 +73,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) => val table = r.table - val writeBuilder = newWriteBuilder(table, query, options) + val writeBuilder = newWriteBuilder(table, options, query.schema) val write = writeBuilder match { case builder: SupportsDynamicOverwrite => builder.overwriteDynamicPartitions().build() @@ -85,12 +86,21 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { case WriteToMicroBatchDataSource( relation, table, query, queryId, writeOptions, outputMode, Some(batchId)) => - val writeBuilder = newWriteBuilder(table, query, writeOptions, queryId) + val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId) val write = buildWriteForMicroBatch(table, writeBuilder, outputMode) val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming) val customMetrics = write.supportedCustomMetrics.toSeq val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf) WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics) + + case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, None) => + val rowSchema = StructType.fromAttributes(rd.dataInput) + val writeBuilder = 
newWriteBuilder(r.table, Map.empty, rowSchema) + val write = writeBuilder.build() + val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf) + // project away any metadata columns that could be used for distribution and ordering + rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery)) + } private def buildWriteForMicroBatch( @@ -119,11 +129,11 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { private def newWriteBuilder( table: Table, - query: LogicalPlan, writeOptions: Map[String, String], + rowSchema: StructType, queryId: String = UUID.randomUUID().toString): WriteBuilder = { - val info = LogicalWriteInfoImpl(queryId, query.schema, writeOptions.asOptions) + val info = LogicalWriteInfoImpl(queryId, rowSchema, writeOptions.asOptions) table.asWritable.newWriteBuilder(info) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 65c49283dd7..d23a9e51f65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -284,6 +284,21 @@ case class OverwritePartitionsDynamicExec( copy(query = newChild) } +/** + * Physical plan node to replace data in existing tables. + */ +case class ReplaceDataExec( + query: SparkPlan, + refreshCache: () => Unit, + write: Write) extends V2ExistingTableWriteExec { + + override val stringArgs: Iterator[Any] = Iterator(query, write) + + override protected def withNewChildInternal(newChild: SparkPlan): ReplaceDataExec = { + copy(query = newChild) + } +} + case class WriteToDataSourceV2Exec( batchWrite: BatchWrite, refreshCache: () => Unit, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala new file mode 100644 index 00000000000..a2cfdde2671 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala @@ -0,0 +1,629 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
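
The V2Writes case above builds the write schema from ReplaceData.dataInput and then projects metadata columns away. Below is a self-contained sketch of that filtering; it is not from the patch and uses a hypothetical `_partition` metadata column like the one exposed by the in-memory test table.

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, MetadataAttribute}
    import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY
    import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructType}

    val id  = AttributeReference("id", IntegerType)()
    val dep = AttributeReference("dep", StringType)()
    // a metadata column: used to cluster/order the write, but never written back
    val part = AttributeReference("_partition", StringType, nullable = false,
      metadata = new MetadataBuilder().putBoolean(METADATA_COL_ATTR_KEY, true).build())()

    val queryOutput = Seq(id, dep, part)

    // mirrors ReplaceData.dataInput: keep only the data columns
    val dataInput = queryOutput.filter {
      case MetadataAttribute(_) => false
      case _ => true
    }

    StructType.fromAttributes(dataInput) // id INT, dep STRING is what the WriteBuilder sees
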
+ */ + +package org.apache.spark.sql.connector + +import java.util.Collections + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, QueryTest, Row} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTableCatalog} +import org.apache.spark.sql.connector.expressions.LogicalExpressions._ +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.QueryExecutionListener + +abstract class DeleteFromTableSuiteBase + extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + import testImplicits._ + + before { + spark.conf.set("spark.sql.catalog.cat", classOf[InMemoryRowLevelOperationTableCatalog].getName) + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.unsetConf("spark.sql.catalog.cat") + } + + private val namespace = Array("ns1") + private val ident = Identifier.of(namespace, "test_table") + private val tableNameAsString = "cat." + ident.toString + + private def catalog: InMemoryRowLevelOperationTableCatalog = { + val catalog = spark.sessionState.catalogManager.catalog("cat") + catalog.asTableCatalog.asInstanceOf[InMemoryRowLevelOperationTableCatalog] + } + + test("EXPLAIN only delete") { + createAndInitTable("id INT, dep STRING", """{ "id": 1, "dep": "hr" }""") + + sql(s"EXPLAIN DELETE FROM $tableNameAsString WHERE id <= 10") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, "hr") :: Nil) + } + + test("delete from empty tables") { + createTable("id INT, dep STRING") + + sql(s"DELETE FROM $tableNameAsString WHERE id <= 1") + + checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) + } + + test("delete with basic filters") { + createAndInitTable("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": 2, "dep": "software" } + |{ "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"DELETE FROM $tableNameAsString WHERE id <= 1") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, "software") :: Row(3, "hr") :: Nil) + } + + test("delete with aliases") { + createAndInitTable("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": 2, "dep": "software" } + |{ "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"DELETE FROM $tableNameAsString AS t WHERE t.id <= 1 OR t.dep = 'hr'") + + checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(2, "software") :: Nil) + } + + test("delete with IN predicates") { + createAndInitTable("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": 2, "dep": "software" } + |{ "id": null, "dep": "hr" } + |""".stripMargin) + + sql(s"DELETE FROM $tableNameAsString WHERE id IN (1, null)") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, "software") :: Row(null, "hr") :: Nil) + } + + test("delete with NOT IN predicates") { + createAndInitTable("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": 2, "dep": "software" } + |{ "id": null, "dep": "hr" } + |""".stripMargin) + + sql(s"DELETE FROM $tableNameAsString WHERE id NOT IN (null, 1)") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, "hr") :: Row(2, "software") :: Row(null, "hr") 
:: Nil) + + sql(s"DELETE FROM $tableNameAsString WHERE id NOT IN (1, 10)") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, "hr") :: Row(null, "hr") :: Nil) + } + + test("delete with conditions on nested columns") { + createAndInitTable("id INT, complex STRUCT<c1:INT,c2:STRING>, dep STRING", + """{ "id": 1, "complex": { "c1": 3, "c2": "v1" }, "dep": "hr" } + |{ "id": 2, "complex": { "c1": 2, "c2": "v2" }, "dep": "software" } + |""".stripMargin) + + sql(s"DELETE FROM $tableNameAsString WHERE complex.c1 = id + 2") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, Row(2, "v2"), "software") :: Nil) + + sql(s"DELETE FROM $tableNameAsString t WHERE t.complex.c1 = id") + + checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) + } + + test("delete with IN subqueries") { + withTempView("deleted_id", "deleted_dep") { + createAndInitTable("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": 2, "dep": "hardware" } + |{ "id": null, "dep": "hr" } + |""".stripMargin) + + val deletedIdDF = Seq(Some(0), Some(1), None).toDF() + deletedIdDF.createOrReplaceTempView("deleted_id") + + val deletedDepDF = Seq("software", "hr").toDF() + deletedDepDF.createOrReplaceTempView("deleted_dep") + + sql( + s"""DELETE FROM $tableNameAsString + |WHERE + | id IN (SELECT * FROM deleted_id) + | AND + | dep IN (SELECT * FROM deleted_dep) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, "hardware") :: Row(null, "hr") :: Nil) + + append("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": -1, "dep": "hr" } + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(-1, "hr") :: Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Nil) + + sql( + s"""DELETE FROM $tableNameAsString + |WHERE + | id IS NULL + | OR + | id IN (SELECT value + 2 FROM deleted_id) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(-1, "hr") :: Row(1, "hr") :: Nil) + + append("id INT, dep STRING", + """{ "id": null, "dep": "hr" } + |{ "id": 2, "dep": "hr" } + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(-1, "hr") :: Row(1, "hr") :: Row(2, "hr") :: Row(null, "hr") :: Nil) + + sql( + s"""DELETE FROM $tableNameAsString + |WHERE + | id IN (SELECT value + 2 FROM deleted_id) + | AND + | dep = 'hr' + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(-1, "hr") :: Row(1, "hr") :: Row(null, "hr") :: Nil) + } + } + + test("delete with multi-column IN subqueries") { + withTempView("deleted_employee") { + createAndInitTable("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": 2, "dep": "hardware" } + |{ "id": null, "dep": "hr" } + |""".stripMargin) + + val deletedEmployeeDF = Seq((None, "hr"), (Some(1), "hr")).toDF() + deletedEmployeeDF.createOrReplaceTempView("deleted_employee") + + sql( + s"""DELETE FROM $tableNameAsString + |WHERE + | (id, dep) IN (SELECT * FROM deleted_employee) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, "hardware") :: Row(null, "hr") :: Nil) + } + } + + test("delete with NOT IN subqueries") { + withTempView("deleted_id", "deleted_dep") { + createAndInitTable("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": 2, "dep": "hardware" } + |{ "id": null, "dep": "hr" } + |""".stripMargin) + + val deletedIdDF = Seq(Some(-1), Some(-2), None).toDF() + deletedIdDF.createOrReplaceTempView("deleted_id") + + val deletedDepDF = Seq("software",
"hr").toDF() + deletedDepDF.createOrReplaceTempView("deleted_dep") + + sql( + s"""DELETE FROM $tableNameAsString + |WHERE + | id NOT IN (SELECT * FROM deleted_id) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Nil) + + sql( + s"""DELETE FROM $tableNameAsString + |WHERE + | id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) + |""".stripMargin) + + checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(null, "hr") :: Nil) + + append("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": 2, "dep": "hardware" } + |{ "id": null, "dep": "hr" } + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Row(null, "hr") :: Nil) + + sql( + s"""DELETE FROM $tableNameAsString + |WHERE + | id NOT IN (SELECT * FROM deleted_id) + | OR + | dep IN ('software', 'hr') + |""".stripMargin) + + checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(2, "hardware") :: Nil) + + sql( + s"""DELETE FROM $tableNameAsString + |WHERE + | id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) + | AND + | EXISTS (SELECT 1 FROM deleted_dep WHERE dep = deleted_dep.value) + |""".stripMargin) + + checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(2, "hardware") :: Nil) + + sql( + s"""DELETE FROM $tableNameAsString t + |WHERE + | t.id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) + | OR + | EXISTS (SELECT 1 FROM deleted_dep WHERE t.dep = deleted_dep.value) + |""".stripMargin) + + checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) + } + } + + test("delete with EXISTS subquery") { + withTempView("deleted_id", "deleted_dep") { + createAndInitTable("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": 2, "dep": "hardware" } + |{ "id": null, "dep": "hr" } + |""".stripMargin) + + val deletedIdDF = Seq(Some(-1), Some(-2), None).toDF() + deletedIdDF.createOrReplaceTempView("deleted_id") + + val deletedDepDF = Seq("software", "hr").toDF() + deletedDepDF.createOrReplaceTempView("deleted_dep") + + sql( + s"""DELETE FROM $tableNameAsString t + |WHERE + | EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, "hr") :: Row(2, "hardware") :: Row(null, "hr") :: Nil) + + sql( + s"""DELETE FROM $tableNameAsString t + |WHERE + | EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, "hardware") :: Row(null, "hr") :: Nil) + + sql( + s"""DELETE FROM $tableNameAsString t + |WHERE + | EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value) OR t.id IS NULL + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, "hardware") :: Nil) + + sql( + s"""DELETE FROM $tableNameAsString t + |WHERE + | EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value) + | AND + | EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, "hardware") :: Nil) + } + } + + test("delete with NOT EXISTS subquery") { + withTempView("deleted_id", "deleted_dep") { + createAndInitTable("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": 2, "dep": "hardware" } + |{ "id": null, "dep": "hr" } + |""".stripMargin) + + val deletedIdDF = Seq(Some(-1), Some(-2), None).toDF() + 
deletedIdDF.createOrReplaceTempView("deleted_id") + + val deletedDepDF = Seq("software", "hr").toDF() + deletedDepDF.createOrReplaceTempView("deleted_dep") + + sql( + s"""DELETE FROM $tableNameAsString t + |WHERE + | NOT EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value + 2) + | AND + | NOT EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, "hr") :: Row(null, "hr") :: Nil) + + sql( + s"""DELETE FROM $tableNameAsString t + |WHERE + | NOT EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2) + |""".stripMargin) + + checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(1, "hr") :: Nil) + + sql( + s"""DELETE FROM $tableNameAsString t + |WHERE + | NOT EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2) + | OR + | t.id = 1 + |""".stripMargin) + + checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) + } + } + + test("delete with a scalar subquery") { + withTempView("deleted_id") { + createAndInitTable("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": 2, "dep": "hardware" } + |{ "id": null, "dep": "hr" } + |""".stripMargin) + + val deletedIdDF = Seq(Some(1), Some(100), None).toDF() + deletedIdDF.createOrReplaceTempView("deleted_id") + + sql( + s"""DELETE FROM $tableNameAsString t + |WHERE + | id <= (SELECT min(value) FROM deleted_id) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, "hardware") :: Row(null, "hr") :: Nil) + } + } + + test("delete refreshes relation cache") { + withTempView("temp") { + withCache("temp") { + createAndInitTable("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": 1, "dep": "hardware" } + |{ "id": 2, "dep": "hardware" } + |{ "id": 3, "dep": "hr" } + |""".stripMargin) + + // define a view on top of the table + val query = sql(s"SELECT * FROM $tableNameAsString WHERE id = 1") + query.createOrReplaceTempView("temp") + + // cache the view + sql("CACHE TABLE temp") + + // verify the view returns expected results + checkAnswer( + sql("SELECT * FROM temp"), + Row(1, "hr") :: Row(1, "hardware") :: Nil) + + // delete some records from the table + sql(s"DELETE FROM $tableNameAsString WHERE id <= 1") + + // verify the delete was successful + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, "hardware") :: Row(3, "hr") :: Nil) + + // verify the view reflects the changes in the table + checkAnswer(sql("SELECT * FROM temp"), Nil) + } + } + } + + test("delete with nondeterministic conditions") { + createAndInitTable("id INT, dep STRING", + """{ "id": 1, "dep": "hr" } + |{ "id": 2, "dep": "software" } + |{ "id": 3, "dep": "hr" } + |""".stripMargin) + + val e = intercept[AnalysisException] { + sql(s"DELETE FROM $tableNameAsString WHERE id <= 1 AND rand() > 0.5") + } + assert(e.message.contains("nondeterministic expressions are only allowed")) + } + + test("delete without condition executed as delete with filters") { + createAndInitTable("id INT, dep INT", + """{ "id": 1, "dep": 100 } + |{ "id": 2, "dep": 200 } + |{ "id": 3, "dep": 100 } + |""".stripMargin) + + executeDeleteWithFilters(s"DELETE FROM $tableNameAsString") + + checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) + } + + test("delete with supported predicates gets converted into delete with filters") { + createAndInitTable("id INT, dep INT", + """{ "id": 1, "dep": 100 } + |{ "id": 2, "dep": 200 } + |{ "id": 3, "dep": 100 } + |""".stripMargin) + + executeDeleteWithFilters(s"DELETE FROM $tableNameAsString WHERE 
dep = 100") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 200) :: Nil) + } + + test("delete with unsupported predicates cannot be converted into delete with filters") { + createAndInitTable("id INT, dep INT", + """{ "id": 1, "dep": 100 } + |{ "id": 2, "dep": 200 } + |{ "id": 3, "dep": 100 } + |""".stripMargin) + + executeDeleteWithRewrite(s"DELETE FROM $tableNameAsString WHERE dep = 100 OR dep < 200") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 200) :: Nil) + } + + test("delete with subquery cannot be converted into delete with filters") { + withTempView("deleted_id") { + createAndInitTable("id INT, dep INT", + """{ "id": 1, "dep": 100 } + |{ "id": 2, "dep": 200 } + |{ "id": 3, "dep": 100 } + |""".stripMargin) + + val deletedIdDF = Seq(Some(1), Some(100), None).toDF() + deletedIdDF.createOrReplaceTempView("deleted_id") + + val q = s"DELETE FROM $tableNameAsString WHERE dep = 100 AND id IN (SELECT * FROM deleted_id)" + executeDeleteWithRewrite(q) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 200) :: Row(3, 100) :: Nil) + } + } + + private def createTable(schemaString: String): Unit = { + val schema = StructType.fromDDL(schemaString) + val tableProps = Collections.emptyMap[String, String] + catalog.createTable(ident, schema, Array(identity(reference(Seq("dep")))), tableProps) + } + + private def createAndInitTable(schemaString: String, jsonData: String): Unit = { + createTable(schemaString) + append(schemaString, jsonData) + } + + private def append(schemaString: String, jsonData: String): Unit = { + val df = toDF(jsonData, schemaString) + df.coalesce(1).writeTo(tableNameAsString).append() + } + + private def toDF(jsonData: String, schemaString: String = null): DataFrame = { + val jsonRows = jsonData.split("\\n").filter(str => str.trim.nonEmpty) + val jsonDS = spark.createDataset(jsonRows)(Encoders.STRING) + if (schemaString == null) { + spark.read.json(jsonDS) + } else { + spark.read.schema(schemaString).json(jsonDS) + } + } + + private def executeDeleteWithFilters(query: String): Unit = { + val executedPlan = executeAndKeepPlan { + sql(query) + } + + executedPlan match { + case _: DeleteFromTableExec => + // OK + case other => + fail("unexpected executed plan: " + other) + } + } + + private def executeDeleteWithRewrite(query: String): Unit = { + val executedPlan = executeAndKeepPlan { + sql(query) + } + + executedPlan match { + case _: ReplaceDataExec => + // OK + case other => + fail("unexpected executed plan: " + other) + } + } + + // executes an operation and keeps the executed plan + private def executeAndKeepPlan(func: => Unit): SparkPlan = { + var executedPlan: SparkPlan = null + + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + executedPlan = qe.executedPlan + } + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + } + } + spark.listenerManager.register(listener) + + func + + sparkContext.listenerBus.waitUntilEmpty() + + stripAQEPlan(executedPlan) + } +} + +class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 24b6be07619..6a20ee21294 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTable, CreateTableAsSelect, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.FakeV2Provider -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table, TableCapability, TableCatalog, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, SupportsDelete, Table, TableCapability, TableCatalog, V1Table} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -49,7 +49,7 @@ class PlanResolutionSuite extends AnalysisTest { private val v2Format = classOf[FakeV2Provider].getName private val table: Table = { - val t = mock(classOf[Table]) + val t = mock(classOf[SupportsDelete]) when(t.schema()).thenReturn(new StructType().add("i", "int").add("s", "string")) when(t.partitioning()).thenReturn(Array.empty[Transform]) t
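
For readers who want to exercise the new code path end to end, below is a minimal sketch that mirrors the test suite in this patch. It assumes a local SparkSession with the sql/core test classes on the classpath (in particular the InMemoryRowLevelOperationTableCatalog added here); the catalog name "cat", the namespace, the table name and the placeholder provider in the USING clause are illustrative choices, not part of this commit.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.InMemoryRowLevelOperationTableCatalog

object DeleteFromTableExample extends App {
  val spark = SparkSession.builder().master("local[1]").appName("delete-example").getOrCreate()

  // Register the row-level in-memory catalog under an illustrative name,
  // exactly as the suite's before() block does.
  spark.conf.set("spark.sql.catalog.cat", classOf[InMemoryRowLevelOperationTableCatalog].getName)

  // Partition by the column used in the filter-based tests above; "foo" is a
  // placeholder provider that the in-memory catalog stores but never interprets.
  spark.sql("CREATE TABLE cat.ns1.test_table (id INT, dep STRING) USING foo PARTITIONED BY (dep)")
  spark.sql("INSERT INTO cat.ns1.test_table VALUES (1, 'hr'), (2, 'software'), (3, 'hr')")

  // A predicate on a non-partition column cannot be served by deleteWhere alone,
  // so it should go through the group-based rewrite (ReplaceDataExec).
  spark.sql("DELETE FROM cat.ns1.test_table WHERE id <= 1")

  // A simple predicate on the partition column can remain a metadata-only delete
  // with filters (DeleteFromTableExec), as in the suite's executeDeleteWithFilters checks.
  spark.sql("DELETE FROM cat.ns1.test_table WHERE dep = 'software'")

  // Expected to print only the remaining row (3, hr).
  spark.sql("SELECT * FROM cat.ns1.test_table").show()

  spark.stop()
}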