This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 1c6bd9e2698 [SPARK-38959][SQL] DS V2: Support runtime group filtering in row-level commands 1c6bd9e2698 is described below commit 1c6bd9e2698aaaca8ccf84154328eb2fa0b484c2 Author: Anton Okolnychyi <aokolnyc...@apple.com> AuthorDate: Tue Oct 11 13:41:30 2022 -0700 [SPARK-38959][SQL] DS V2: Support runtime group filtering in row-level commands ### What changes were proposed in this pull request? This PR adds runtime group filtering for group-based row-level operations. ### Why are the changes needed? These changes are needed to avoid rewriting unnecessary groups as the data skipping during job planning is limited and can still report false positive groups to rewrite. ### Does this PR introduce _any_ user-facing change? This PR leverages existing APIs. ### How was this patch tested? This PR comes with tests. Closes #36304 from aokolnychyi/spark-38959. Lead-authored-by: Anton Okolnychyi <aokolnyc...@apple.com> Co-authored-by: aokolnychyi <aokolnyc...@apple.com> Signed-off-by: Dongjoon Hyun <dongj...@apache.org> --- .../sql/connector/write/RowLevelOperation.java | 14 ++ .../org/apache/spark/sql/internal/SQLConf.scala | 18 +++ .../catalog/InMemoryRowLevelOperationTable.scala | 6 +- .../spark/sql/execution/SparkOptimizer.scala | 5 +- .../PlanAdaptiveDynamicPruningFilters.scala | 2 +- .../dynamicpruning/PlanDynamicPruningFilters.scala | 2 +- .../RowLevelOperationRuntimeGroupFiltering.scala | 98 ++++++++++++ ...eSuite.scala => DeleteFromTableSuiteBase.scala} | 22 +-- .../connector/GroupBasedDeleteFromTableSuite.scala | 166 +++++++++++++++++++++ 9 files changed, 318 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java index 7acd27759a1..844734ff7cc 100644 --- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java @@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.read.Scan; import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering; import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** @@ -68,6 +69,19 @@ public interface RowLevelOperation { * be returned by the scan, even if a filter can narrow the set of changes to a single file * in the partition. Similarly, a data source that can swap individual files must produce all * rows from files where at least one record must be changed, not just rows that must be changed. + * <p> + * Data sources that replace groups of data (e.g. files, partitions) may prune entire groups + * using provided data source filters when building a scan for this row-level operation. + * However, such data skipping is limited as not all expressions can be converted into data source + * filters and some can only be evaluated by Spark (e.g. subqueries). Since rewriting groups is + * expensive, Spark allows group-based data sources to filter groups at runtime. The runtime + * filtering enables data sources to narrow down the scope of rewriting to only groups that must + * be rewritten. If the row-level operation scan implements {@link SupportsRuntimeV2Filtering}, + * Spark will execute a query at runtime to find which records match the row-level condition. + * The runtime group filter subquery will leverage a regular batch scan, which isn't required to + * produce all rows in a group if any are returned. The information about matching records will + * be passed back into the row-level operation scan, allowing data sources to discard groups + * that don't have to be rewritten. 
*/ ScanBuilder newScanBuilder(CaseInsensitiveStringMap options); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index bbe5bdd7035..1c981aa3950 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -412,6 +412,21 @@ object SQLConf { .longConf .createWithDefault(67108864L) + val RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED = + buildConf("spark.sql.optimizer.runtime.rowLevelOperationGroupFilter.enabled") + .doc("Enables runtime group filtering for group-based row-level operations. " + + "Data sources that replace groups of data (e.g. files, partitions) may prune entire " + + "groups using provided data source filters when planning a row-level operation scan. " + + "However, such filtering is limited as not all expressions can be converted into data " + + "source filters and some expressions can only be evaluated by Spark (e.g. subqueries). " + + "Since rewriting groups is expensive, Spark can execute a query at runtime to find what " + + "records match the condition of the row-level operation. 
The information about matching " + + "records will be passed back to the row-level operation scan, allowing data sources to " + + "discard groups that don't have to be rewritten.") + .version("3.4.0") + .booleanConf + .createWithDefault(true) + val PLANNED_WRITE_ENABLED = buildConf("spark.sql.optimizer.plannedWrite.enabled") .internal() .doc("When set to true, Spark optimizer will add logical sort operators to V1 write commands " + @@ -4091,6 +4106,9 @@ class SQLConf extends Serializable with Logging { def runtimeFilterCreationSideThreshold: Long = getConf(RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD) + def runtimeRowLevelOperationGroupFilterEnabled: Boolean = + getConf(RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED) + def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) def isStateSchemaCheckEnabled: Boolean = getConf(STATE_SCHEMA_CHECK_ENABLED) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala index cb061602ec1..08c22a02b85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -34,6 +34,9 @@ class InMemoryRowLevelOperationTable( properties: util.Map[String, String]) extends InMemoryTable(name, schema, partitioning, properties) with SupportsRowLevelOperations { + // used in row-level operation tests to verify replaced partitions + var replacedPartitions: Seq[Seq[Any]] = Seq.empty + override def newRowLevelOperationBuilder( info: RowLevelOperationInfo): RowLevelOperationBuilder = { () => PartitionBasedOperation(info.command) @@ -88,8 +91,9 @@ class InMemoryRowLevelOperationTable( override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { val newData = 
messages.map(_.asInstanceOf[BufferedRows]) val readRows = scan.data.flatMap(_.asInstanceOf[BufferedRows].rows) - val readPartitions = readRows.map(r => getKey(r, schema)) + val readPartitions = readRows.map(r => getKey(r, schema)).distinct dataMap --= readPartitions + replacedPartitions = readPartitions withData(newData, schema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 72bdab409a9..017d1f937c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes} import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes} -import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning} +import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering} import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs} class SparkOptimizer( @@ -50,7 +50,8 @@ class SparkOptimizer( override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("PartitionPruning", Once, - PartitionPruning) :+ + PartitionPruning, + RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)) :+ Batch("InjectRuntimeFilter", FixedPoint(1), InjectRuntimeFilter) :+ Batch("MergeScalarSubqueries", Once, diff 
--git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala index 9a780c11eef..21bc55110fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashedRelati case class PlanAdaptiveDynamicPruningFilters( rootPlan: AdaptiveSparkPlanExec) extends Rule[SparkPlan] with AdaptiveSparkPlanHelper { def apply(plan: SparkPlan): SparkPlan = { - if (!conf.dynamicPartitionPruningEnabled) { + if (!conf.dynamicPartitionPruningEnabled && !conf.runtimeRowLevelOperationGroupFilterEnabled) { return plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala index c9ff28eb045..df5e3ea1365 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala @@ -45,7 +45,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp } override def apply(plan: SparkPlan): SparkPlan = { - if (!conf.dynamicPartitionPruningEnabled) { + if (!conf.dynamicPartitionPruningEnabled && !conf.runtimeRowLevelOperationGroupFilterEnabled) { return plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala new file mode 100644 index 00000000000..232c320bcd4 --- 
/dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.dynamicpruning + +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruningSubquery, Expression, PredicateHelper, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.planning.GroupBasedRowLevelOperation +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation, DataSourceV2ScanRelation} + +/** + * A rule that assigns a subquery to filter groups in row-level operations at runtime. + * + * Data skipping during job planning for row-level operations is limited to expressions that can be + * converted to data source filters.
Since not all expressions can be pushed down that way and + * rewriting groups is expensive, Spark allows data sources to filter groups at runtime. + * If the primary scan in a group-based row-level operation supports runtime filtering, this rule + * will inject a subquery to find all rows that match the condition so that data sources know + * exactly which groups must be rewritten. + * + * Note this rule only applies to group-based row-level operations. + */ +case class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPlan]) + extends Rule[LogicalPlan] with PredicateHelper { + + import DataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + // only for group-based row-level ops whose scan reports filter attributes (none => no keys) + case GroupBasedRowLevelOperation(replaceData, cond, + DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _)) + if conf.runtimeRowLevelOperationGroupFilterEnabled && cond != TrueLiteral && scan.filterAttributes.nonEmpty => + + // use reference equality on scan to find required scan relations + val newQuery = replaceData.query transformUp { + case r: DataSourceV2ScanRelation if r.scan eq scan => + // use the original table instance that was loaded for this row-level operation + // in order to leverage a regular batch scan in the group filter query + val originalTable = r.relation.table.asRowLevelOperationTable.table + val relation = r.relation.copy(table = originalTable) + val matchingRowsPlan = buildMatchingRowsPlan(relation, cond) + + val filterAttrs = scan.filterAttributes + val buildKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, matchingRowsPlan) + val pruningKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, r) + val dynamicPruningCond = buildDynamicPruningCond(matchingRowsPlan, buildKeys, pruningKeys) + + Filter(dynamicPruningCond, r) + } + + // optimize subqueries to rewrite them as joins and trigger job planning + replaceData.copy(query = optimizeSubqueries(newQuery))
+ } + + private def buildMatchingRowsPlan( + relation: DataSourceV2Relation, + cond: Expression): LogicalPlan = { + + val matchingRowsPlan = Filter(cond, relation) + + // clone the relation and assign new expr IDs to avoid conflicts + matchingRowsPlan transformUpWithNewOutput { + case r: DataSourceV2Relation if r eq relation => + val oldOutput = r.output + val newOutput = oldOutput.map(_.newInstance()) + r.copy(output = newOutput) -> oldOutput.zip(newOutput) + } + } + + private def buildDynamicPruningCond( + matchingRowsPlan: LogicalPlan, + buildKeys: Seq[Attribute], + pruningKeys: Seq[Attribute]): Expression = { + + val buildQuery = Project(buildKeys, matchingRowsPlan) + val dynamicPruningSubqueries = pruningKeys.zipWithIndex.map { case (key, index) => + DynamicPruningSubquery(key, buildQuery, buildKeys, index, onlyInBroadcast = false) + } + dynamicPruningSubqueries.reduce(And) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index a2cfdde2671..d9a12b47ec2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -22,7 +22,7 @@ import java.util.Collections import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, QueryTest, Row} -import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTableCatalog} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog} import org.apache.spark.sql.connector.expressions.LogicalExpressions._ import
org.apache.spark.sql.execution.{QueryExecution, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -46,15 +46,19 @@ abstract class DeleteFromTableSuiteBase spark.sessionState.conf.unsetConf("spark.sql.catalog.cat") } - private val namespace = Array("ns1") - private val ident = Identifier.of(namespace, "test_table") - private val tableNameAsString = "cat." + ident.toString + protected val namespace: Array[String] = Array("ns1") + protected val ident: Identifier = Identifier.of(namespace, "test_table") + protected val tableNameAsString: String = "cat." + ident.toString - private def catalog: InMemoryRowLevelOperationTableCatalog = { + protected def catalog: InMemoryRowLevelOperationTableCatalog = { val catalog = spark.sessionState.catalogManager.catalog("cat") catalog.asTableCatalog.asInstanceOf[InMemoryRowLevelOperationTableCatalog] } + protected def table: InMemoryRowLevelOperationTable = { + catalog.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] + } + test("EXPLAIN only delete") { createAndInitTable("id INT, dep STRING", """{ "id": 1, "dep": "hr" }""") @@ -553,13 +557,13 @@ abstract class DeleteFromTableSuiteBase } } - private def createTable(schemaString: String): Unit = { + protected def createTable(schemaString: String): Unit = { val schema = StructType.fromDDL(schemaString) val tableProps = Collections.emptyMap[String, String] catalog.createTable(ident, schema, Array(identity(reference(Seq("dep")))), tableProps) } - private def createAndInitTable(schemaString: String, jsonData: String): Unit = { + protected def createAndInitTable(schemaString: String, jsonData: String): Unit = { createTable(schemaString) append(schemaString, jsonData) } @@ -606,7 +610,7 @@ abstract class DeleteFromTableSuiteBase } // executes an operation and keeps the executed plan - private def executeAndKeepPlan(func: => Unit): SparkPlan = { + protected def executeAndKeepPlan(func: => Unit): SparkPlan = { var executedPlan: SparkPlan = null 
val listener = new QueryExecutionListener { @@ -625,5 +629,3 @@ abstract class DeleteFromTableSuiteBase stripAQEPlan(executedPlan) } } - -class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala new file mode 100644 index 00000000000..36905027cb0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression +import org.apache.spark.sql.execution.InSubqueryExec +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { + + import testImplicits._ + + test("delete with IN predicate and runtime group filtering") { + createAndInitTable("id INT, salary INT, dep STRING", + """{ "id": 1, "salary": 300, "dep": 'hr' } + |{ "id": 2, "salary": 150, "dep": 'software' } + |{ "id": 3, "salary": 120, "dep": 'hr' } + |""".stripMargin) + + executeDeleteAndCheckScans( + s"DELETE FROM $tableNameAsString WHERE salary IN (300, 400, 500)", + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + groupFilterScanSchema = "salary INT, dep STRING") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) + + checkReplacedPartitions(Seq("hr")) + } + + test("delete with subqueries and runtime group filtering") { + withTempView("deleted_id", "deleted_dep") { + createAndInitTable("id INT, salary INT, dep STRING", + """{ "id": 1, "salary": 300, "dep": 'hr' } + |{ "id": 2, "salary": 150, "dep": 'software' } + |{ "id": 3, "salary": 120, "dep": 'hr' } + |{ "id": 4, "salary": 150, "dep": 'software' } + |""".stripMargin) + + val deletedIdDF = Seq(Some(2), None).toDF() + deletedIdDF.createOrReplaceTempView("deleted_id") + + val deletedDepDF = Seq(Some("software"), None).toDF() + deletedDepDF.createOrReplaceTempView("deleted_dep") + + executeDeleteAndCheckScans( + s"""DELETE FROM $tableNameAsString + |WHERE + | id IN (SELECT * FROM deleted_id) + | AND + | dep IN (SELECT * FROM deleted_dep) + |""".stripMargin, + primaryScanSchema = "id INT, salary INT, 
dep STRING, _partition STRING", + groupFilterScanSchema = "id INT, dep STRING") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, 300, "hr") :: Row(3, 120, "hr") :: Row(4, 150, "software") :: Nil) + + checkReplacedPartitions(Seq("software")) + } + } + + test("delete runtime group filtering (DPP enabled)") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { + checkDeleteRuntimeGroupFiltering() + } + } + + test("delete runtime group filtering (DPP disabled)") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "false") { + checkDeleteRuntimeGroupFiltering() + } + } + + test("delete runtime group filtering (AQE enabled)") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + checkDeleteRuntimeGroupFiltering() + } + } + + test("delete runtime group filtering (AQE disabled)") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + checkDeleteRuntimeGroupFiltering() + } + } + + private def checkDeleteRuntimeGroupFiltering(): Unit = { + withTempView("deleted_id") { + createAndInitTable("id INT, salary INT, dep STRING", + """{ "id": 1, "salary": 300, "dep": 'hr' } + |{ "id": 2, "salary": 150, "dep": 'software' } + |{ "id": 3, "salary": 120, "dep": 'hr' } + |""".stripMargin) + + val deletedIdDF = Seq(Some(1), None).toDF() + deletedIdDF.createOrReplaceTempView("deleted_id") + + executeDeleteAndCheckScans( + s"DELETE FROM $tableNameAsString WHERE id IN (SELECT * FROM deleted_id)", + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + groupFilterScanSchema = "id INT, dep STRING") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) + + checkReplacedPartitions(Seq("hr")) + } + } + + private def executeDeleteAndCheckScans( + query: String, + primaryScanSchema: String, + groupFilterScanSchema: String): Unit = { + + val executedPlan = executeAndKeepPlan { + sql(query) + } + + val primaryScan = 
collect(executedPlan) { + case s: BatchScanExec => s + }.head + assert(primaryScan.schema.sameType(StructType.fromDDL(primaryScanSchema))) + + primaryScan.runtimeFilters match { + case Seq(DynamicPruningExpression(child: InSubqueryExec)) => + val groupFilterScan = collect(child.plan) { + case s: BatchScanExec => s + }.head + assert(groupFilterScan.schema.sameType(StructType.fromDDL(groupFilterScanSchema))) + + case _ => + fail("could not find group filter scan") + } + } + + private def checkReplacedPartitions(expectedPartitions: Seq[Any]): Unit = { + val actualPartitions = table.replacedPartitions.map { + case Seq(partValue: UTF8String) => partValue.toString + case Seq(partValue) => partValue + case other => fail(s"expected only one partition value: $other" ) + } + assert(actualPartitions == expectedPartitions, "replaced partitions must match") + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org