This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1c6bd9e2698 [SPARK-38959][SQL] DS V2: Support runtime group filtering in row-level commands
1c6bd9e2698 is described below

commit 1c6bd9e2698aaaca8ccf84154328eb2fa0b484c2
Author: Anton Okolnychyi <aokolnyc...@apple.com>
AuthorDate: Tue Oct 11 13:41:30 2022 -0700

    [SPARK-38959][SQL] DS V2: Support runtime group filtering in row-level commands
    
    ### What changes were proposed in this pull request?
    
    This PR adds runtime group filtering for group-based row-level operations.
    
    ### Why are the changes needed?
    
    These changes are needed to avoid rewriting unnecessary groups, because data skipping during job planning is limited and can still report false-positive groups to rewrite.
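    
    For example (an illustrative sketch only; the table and view names below are hypothetical and simply follow the pattern used in the new tests), a delete condition with a subquery cannot be pushed down as a data source filter, so the static scan may report every partition as a candidate. The runtime group filter narrows the rewrite down to the partitions that actually contain matching rows:
    
        // hypothetical names; runtime group filtering decides which partitions get rewritten
        spark.sql(
          """DELETE FROM cat.ns1.test_table
            |WHERE id IN (SELECT id FROM deleted_ids)""".stripMargin)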
    
    ### Does this PR introduce _any_ user-facing change?
    
    This PR leverages existing APIs.
    
    ### How was this patch tested?
    
    This PR comes with tests.
    
    Closes #36304 from aokolnychyi/spark-38959.
    
    Lead-authored-by: Anton Okolnychyi <aokolnyc...@apple.com>
    Co-authored-by: aokolnychyi <aokolnyc...@apple.com>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../sql/connector/write/RowLevelOperation.java     |  14 ++
 .../org/apache/spark/sql/internal/SQLConf.scala    |  18 +++
 .../catalog/InMemoryRowLevelOperationTable.scala   |   6 +-
 .../spark/sql/execution/SparkOptimizer.scala       |   5 +-
 .../PlanAdaptiveDynamicPruningFilters.scala        |   2 +-
 .../dynamicpruning/PlanDynamicPruningFilters.scala |   2 +-
 .../RowLevelOperationRuntimeGroupFiltering.scala   |  98 ++++++++++++
 ...eSuite.scala => DeleteFromTableSuiteBase.scala} |  22 +--
 .../connector/GroupBasedDeleteFromTableSuite.scala | 166 +++++++++++++++++++++
 9 files changed, 318 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java
index 7acd27759a1..844734ff7cc 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental;
 import org.apache.spark.sql.connector.expressions.NamedReference;
 import org.apache.spark.sql.connector.read.Scan;
 import org.apache.spark.sql.connector.read.ScanBuilder;
+import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 
 /**
@@ -68,6 +69,19 @@ public interface RowLevelOperation {
   * be returned by the scan, even if a filter can narrow the set of changes to a single file
   * in the partition. Similarly, a data source that can swap individual files must produce all
   * rows from files where at least one record must be changed, not just rows that must be changed.
+   * <p>
+   * Data sources that replace groups of data (e.g. files, partitions) may prune entire groups
+   * using provided data source filters when building a scan for this row-level operation.
+   * However, such data skipping is limited as not all expressions can be converted into data source
+   * filters and some can only be evaluated by Spark (e.g. subqueries). Since rewriting groups is
+   * expensive, Spark allows group-based data sources to filter groups at runtime. The runtime
+   * filtering enables data sources to narrow down the scope of rewriting to only groups that must
+   * be rewritten. If the row-level operation scan implements {@link SupportsRuntimeV2Filtering},
+   * Spark will execute a query at runtime to find which records match the row-level condition.
+   * The runtime group filter subquery will leverage a regular batch scan, which isn't required to
+   * produce all rows in a group if any are returned. The information about matching records will
+   * be passed back into the row-level operation scan, allowing data sources to discard groups
+   * that don't have to be rewritten.
    */
   ScanBuilder newScanBuilder(CaseInsensitiveStringMap options);
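
As a rough illustration of the contract described above, a group-based connector's row-level operation scan might opt into runtime group filtering roughly as follows. This is a hedged sketch: only the SupportsRuntimeV2Filtering methods (filterAttributes/filter) and the imported Spark classes come from the DS V2 API; the FileGroup abstraction, the class name, and the pruning logic are hypothetical.

    import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference}
    import org.apache.spark.sql.connector.expressions.filter.Predicate
    import org.apache.spark.sql.connector.read.{Scan, SupportsRuntimeV2Filtering}
    import org.apache.spark.sql.types.StructType

    // hypothetical group abstraction: one replaceable unit (e.g. a partition) plus a
    // connector-specific check against the runtime predicates pushed by Spark
    case class FileGroup(partitionValue: String) {
      def mightMatch(predicates: Array[Predicate]): Boolean = true // connector-specific pruning
    }

    // hypothetical scan backing a row-level operation; batch planning is elided
    class GroupFilterableRowLevelScan(
        schema: StructType,
        private var candidateGroups: Seq[FileGroup]) extends Scan with SupportsRuntimeV2Filtering {

      override def readSchema(): StructType = schema

      // ask Spark to filter at runtime on the grouping column, e.g. the partition column "dep"
      override def filterAttributes(): Array[NamedReference] = Array(Expressions.column("dep"))

      // Spark calls this with predicates derived from the runtime group filter subquery;
      // the connector keeps only groups whose partition values might match
      override def filter(predicates: Array[Predicate]): Unit = {
        candidateGroups = candidateGroups.filter(_.mightMatch(predicates))
      }
    }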
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index bbe5bdd7035..1c981aa3950 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -412,6 +412,21 @@ object SQLConf {
       .longConf
       .createWithDefault(67108864L)
 
+  val RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED =
+    buildConf("spark.sql.optimizer.runtime.rowLevelOperationGroupFilter.enabled")
+      .doc("Enables runtime group filtering for group-based row-level operations. " +
+        "Data sources that replace groups of data (e.g. files, partitions) may prune entire " +
+        "groups using provided data source filters when planning a row-level operation scan. " +
+        "However, such filtering is limited as not all expressions can be converted into data " +
+        "source filters and some expressions can only be evaluated by Spark (e.g. subqueries). " +
+        "Since rewriting groups is expensive, Spark can execute a query at runtime to find what " +
+        "records match the condition of the row-level operation. The information about matching " +
+        "records will be passed back to the row-level operation scan, allowing data sources to " +
+        "discard groups that don't have to be rewritten.")
+      .version("3.4.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val PLANNED_WRITE_ENABLED = buildConf("spark.sql.optimizer.plannedWrite.enabled")
     .internal()
     .doc("When set to true, Spark optimizer will add logical sort operators to V1 write commands " +
@@ -4091,6 +4106,9 @@ class SQLConf extends Serializable with Logging {
   def runtimeFilterCreationSideThreshold: Long =
     getConf(RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD)
 
+  def runtimeRowLevelOperationGroupFilterEnabled: Boolean =
+    getConf(RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED)
+
   def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS)
 
   def isStateSchemaCheckEnabled: Boolean = getConf(STATE_SCHEMA_CHECK_ENABLED)
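
For reference, the new behavior can be toggled per session; the config key comes from the diff above, while the snippet itself is just a usage sketch:

    // disable runtime group filtering for the current session (it defaults to true)
    spark.conf.set("spark.sql.optimizer.runtime.rowLevelOperationGroupFilter.enabled", "false")

    // or equivalently through SQL
    spark.sql("SET spark.sql.optimizer.runtime.rowLevelOperationGroupFilter.enabled=false")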
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
index cb061602ec1..08c22a02b85 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
@@ -34,6 +34,9 @@ class InMemoryRowLevelOperationTable(
     properties: util.Map[String, String])
   extends InMemoryTable(name, schema, partitioning, properties) with SupportsRowLevelOperations {
 
+  // used in row-level operation tests to verify replaced partitions
+  var replacedPartitions: Seq[Seq[Any]] = Seq.empty
+
   override def newRowLevelOperationBuilder(
       info: RowLevelOperationInfo): RowLevelOperationBuilder = {
     () => PartitionBasedOperation(info.command)
@@ -88,8 +91,9 @@ class InMemoryRowLevelOperationTable(
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
       val newData = messages.map(_.asInstanceOf[BufferedRows])
       val readRows = scan.data.flatMap(_.asInstanceOf[BufferedRows].rows)
-      val readPartitions = readRows.map(r => getKey(r, schema))
+      val readPartitions = readRows.map(r => getKey(r, schema)).distinct
       dataMap --= readPartitions
+      replacedPartitions = readPartitions
       withData(newData, schema)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 72bdab409a9..017d1f937c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes}
 import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
-import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning}
+import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering}
 import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
 
 class SparkOptimizer(
@@ -50,7 +50,8 @@ class SparkOptimizer(
   override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
     Batch("PartitionPruning", Once,
-      PartitionPruning) :+
+      PartitionPruning,
+      RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)) :+
     Batch("InjectRuntimeFilter", FixedPoint(1),
       InjectRuntimeFilter) :+
     Batch("MergeScalarSubqueries", Once,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
index 9a780c11eef..21bc55110fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashedRelati
 case class PlanAdaptiveDynamicPruningFilters(
     rootPlan: AdaptiveSparkPlanExec) extends Rule[SparkPlan] with AdaptiveSparkPlanHelper {
   def apply(plan: SparkPlan): SparkPlan = {
-    if (!conf.dynamicPartitionPruningEnabled) {
+    if (!conf.dynamicPartitionPruningEnabled && !conf.runtimeRowLevelOperationGroupFilterEnabled) {
       return plan
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
index c9ff28eb045..df5e3ea1365 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
@@ -45,7 +45,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp
   }
 
   override def apply(plan: SparkPlan): SparkPlan = {
-    if (!conf.dynamicPartitionPruningEnabled) {
+    if (!conf.dynamicPartitionPruningEnabled && !conf.runtimeRowLevelOperationGroupFilterEnabled) {
       return plan
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala
new file mode 100644
index 00000000000..232c320bcd4
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.dynamicpruning
+
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruningSubquery, Expression, PredicateHelper, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
+import org.apache.spark.sql.catalyst.planning.GroupBasedRowLevelOperation
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation, DataSourceV2ScanRelation}
+
+/**
+ * A rule that assigns a subquery to filter groups in row-level operations at runtime.
+ *
+ * Data skipping during job planning for row-level operations is limited to expressions that can be
+ * converted to data source filters. Since not all expressions can be pushed down that way and
+ * rewriting groups is expensive, Spark allows data sources to filter groups at runtime.
+ * If the primary scan in a group-based row-level operation supports runtime filtering, this rule
+ * will inject a subquery to find all rows that match the condition so that data sources know
+ * exactly which groups must be rewritten.
+ *
+ * Note this rule only applies to group-based row-level operations.
+ */
+case class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPlan])
+  extends Rule[LogicalPlan] with PredicateHelper {
+
+  import DataSourceV2Implicits._
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    // apply special dynamic filtering only for group-based row-level operations
+    case GroupBasedRowLevelOperation(replaceData, cond,
+        DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _))
+        if conf.runtimeRowLevelOperationGroupFilterEnabled && cond != TrueLiteral =>
+
+      // use reference equality on scan to find required scan relations
+      val newQuery = replaceData.query transformUp {
+        case r: DataSourceV2ScanRelation if r.scan eq scan =>
+          // use the original table instance that was loaded for this row-level operation
+          // in order to leverage a regular batch scan in the group filter query
+          val originalTable = r.relation.table.asRowLevelOperationTable.table
+          val relation = r.relation.copy(table = originalTable)
+          val matchingRowsPlan = buildMatchingRowsPlan(relation, cond)
+
+          val filterAttrs = scan.filterAttributes
+          val buildKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, matchingRowsPlan)
+          val pruningKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, r)
+          val dynamicPruningCond = buildDynamicPruningCond(matchingRowsPlan, buildKeys, pruningKeys)
+
+          Filter(dynamicPruningCond, r)
+      }
+
+      // optimize subqueries to rewrite them as joins and trigger job planning
+      replaceData.copy(query = optimizeSubqueries(newQuery))
+  }
+
+  private def buildMatchingRowsPlan(
+      relation: DataSourceV2Relation,
+      cond: Expression): LogicalPlan = {
+
+    val matchingRowsPlan = Filter(cond, relation)
+
+    // clone the relation and assign new expr IDs to avoid conflicts
+    matchingRowsPlan transformUpWithNewOutput {
+      case r: DataSourceV2Relation if r eq relation =>
+        val oldOutput = r.output
+        val newOutput = oldOutput.map(_.newInstance())
+        r.copy(output = newOutput) -> oldOutput.zip(newOutput)
+    }
+  }
+
+  private def buildDynamicPruningCond(
+      matchingRowsPlan: LogicalPlan,
+      buildKeys: Seq[Attribute],
+      pruningKeys: Seq[Attribute]): Expression = {
+
+    val buildQuery = Project(buildKeys, matchingRowsPlan)
+    val dynamicPruningSubqueries = pruningKeys.zipWithIndex.map { case (key, index) =>
+      DynamicPruningSubquery(key, buildQuery, buildKeys, index, onlyInBroadcast = false)
+    }
+    dynamicPruningSubqueries.reduce(And)
+  }
+}
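
Conceptually, the rewrite performed by this rule reshapes the plan roughly as follows (a simplified sketch of the operator shapes, not literal optimizer output):

    // Before (simplified): the row-level operation scan feeds the rewrite directly
    //   ReplaceData
    //   +- <row-level operation query>
    //      +- DataSourceV2ScanRelation(row-level operation scan)
    //
    // After (simplified): a dynamic pruning filter built from Filter(cond, <regular table relation>)
    // is injected on top of the scan, so the source can discard groups with no matching rows
    //   ReplaceData
    //   +- <row-level operation query>
    //      +- Filter(DynamicPruningSubquery(pruningKeys, Project(buildKeys, Filter(cond, <table>))))
    //         +- DataSourceV2ScanRelation(row-level operation scan)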
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
similarity index 96%
rename from sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
index a2cfdde2671..d9a12b47ec2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
@@ -22,7 +22,7 @@ import java.util.Collections
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, QueryTest, Row}
-import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTableCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog}
 import org.apache.spark.sql.connector.expressions.LogicalExpressions._
 import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -46,15 +46,19 @@ abstract class DeleteFromTableSuiteBase
     spark.sessionState.conf.unsetConf("spark.sql.catalog.cat")
   }
 
-  private val namespace = Array("ns1")
-  private val ident = Identifier.of(namespace, "test_table")
-  private val tableNameAsString = "cat." + ident.toString
+  protected val namespace: Array[String] = Array("ns1")
+  protected val ident: Identifier = Identifier.of(namespace, "test_table")
+  protected val tableNameAsString: String = "cat." + ident.toString
 
-  private def catalog: InMemoryRowLevelOperationTableCatalog = {
+  protected def catalog: InMemoryRowLevelOperationTableCatalog = {
     val catalog = spark.sessionState.catalogManager.catalog("cat")
     catalog.asTableCatalog.asInstanceOf[InMemoryRowLevelOperationTableCatalog]
   }
 
+  protected def table: InMemoryRowLevelOperationTable = {
+    catalog.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable]
+  }
+
   test("EXPLAIN only delete") {
     createAndInitTable("id INT, dep STRING", """{ "id": 1, "dep": "hr" }""")
 
@@ -553,13 +557,13 @@ abstract class DeleteFromTableSuiteBase
     }
   }
 
-  private def createTable(schemaString: String): Unit = {
+  protected def createTable(schemaString: String): Unit = {
     val schema = StructType.fromDDL(schemaString)
     val tableProps = Collections.emptyMap[String, String]
     catalog.createTable(ident, schema, Array(identity(reference(Seq("dep")))), tableProps)
   }
 
-  private def createAndInitTable(schemaString: String, jsonData: String): Unit = {
+  protected def createAndInitTable(schemaString: String, jsonData: String): Unit = {
     createTable(schemaString)
     append(schemaString, jsonData)
   }
@@ -606,7 +610,7 @@ abstract class DeleteFromTableSuiteBase
   }
 
   // executes an operation and keeps the executed plan
-  private def executeAndKeepPlan(func: => Unit): SparkPlan = {
+  protected def executeAndKeepPlan(func: => Unit): SparkPlan = {
     var executedPlan: SparkPlan = null
 
     val listener = new QueryExecutionListener {
@@ -625,5 +629,3 @@ abstract class DeleteFromTableSuiteBase
     stripAQEPlan(executedPlan)
   }
 }
-
-class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
new file mode 100644
index 00000000000..36905027cb0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression
+import org.apache.spark.sql.execution.InSubqueryExec
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
+
+class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase {
+
+  import testImplicits._
+
+  test("delete with IN predicate and runtime group filtering") {
+    createAndInitTable("id INT, salary INT, dep STRING",
+      """{ "id": 1, "salary": 300, "dep": 'hr' }
+        |{ "id": 2, "salary": 150, "dep": 'software' }
+        |{ "id": 3, "salary": 120, "dep": 'hr' }
+        |""".stripMargin)
+
+    executeDeleteAndCheckScans(
+      s"DELETE FROM $tableNameAsString WHERE salary IN (300, 400, 500)",
+      primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING",
+      groupFilterScanSchema = "salary INT, dep STRING")
+
+    checkAnswer(
+      sql(s"SELECT * FROM $tableNameAsString"),
+      Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil)
+
+    checkReplacedPartitions(Seq("hr"))
+  }
+
+  test("delete with subqueries and runtime group filtering") {
+    withTempView("deleted_id", "deleted_dep") {
+      createAndInitTable("id INT, salary INT, dep STRING",
+        """{ "id": 1, "salary": 300, "dep": 'hr' }
+          |{ "id": 2, "salary": 150, "dep": 'software' }
+          |{ "id": 3, "salary": 120, "dep": 'hr' }
+          |{ "id": 4, "salary": 150, "dep": 'software' }
+          |""".stripMargin)
+
+      val deletedIdDF = Seq(Some(2), None).toDF()
+      deletedIdDF.createOrReplaceTempView("deleted_id")
+
+      val deletedDepDF = Seq(Some("software"), None).toDF()
+      deletedDepDF.createOrReplaceTempView("deleted_dep")
+
+      executeDeleteAndCheckScans(
+        s"""DELETE FROM $tableNameAsString
+           |WHERE
+           | id IN (SELECT * FROM deleted_id)
+           | AND
+           | dep IN (SELECT * FROM deleted_dep)
+           |""".stripMargin,
+        primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING",
+        groupFilterScanSchema = "id INT, dep STRING")
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Row(1, 300, "hr") :: Row(3, 120, "hr") :: Row(4, 150, "software") :: Nil)
+
+      checkReplacedPartitions(Seq("software"))
+    }
+  }
+
+  test("delete runtime group filtering (DPP enabled)") {
+    withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
+      checkDeleteRuntimeGroupFiltering()
+    }
+  }
+
+  test("delete runtime group filtering (DPP disabled)") {
+    withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "false") {
+      checkDeleteRuntimeGroupFiltering()
+    }
+  }
+
+  test("delete runtime group filtering (AQE enabled)") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      checkDeleteRuntimeGroupFiltering()
+    }
+  }
+
+  test("delete runtime group filtering (AQE disabled)") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+      checkDeleteRuntimeGroupFiltering()
+    }
+  }
+
+  private def checkDeleteRuntimeGroupFiltering(): Unit = {
+    withTempView("deleted_id") {
+      createAndInitTable("id INT, salary INT, dep STRING",
+        """{ "id": 1, "salary": 300, "dep": 'hr' }
+          |{ "id": 2, "salary": 150, "dep": 'software' }
+          |{ "id": 3, "salary": 120, "dep": 'hr' }
+          |""".stripMargin)
+
+      val deletedIdDF = Seq(Some(1), None).toDF()
+      deletedIdDF.createOrReplaceTempView("deleted_id")
+
+      executeDeleteAndCheckScans(
+        s"DELETE FROM $tableNameAsString WHERE id IN (SELECT * FROM 
deleted_id)",
+        primaryScanSchema = "id INT, salary INT, dep STRING, _partition 
STRING",
+        groupFilterScanSchema = "id INT, dep STRING")
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil)
+
+      checkReplacedPartitions(Seq("hr"))
+    }
+  }
+
+  private def executeDeleteAndCheckScans(
+      query: String,
+      primaryScanSchema: String,
+      groupFilterScanSchema: String): Unit = {
+
+    val executedPlan = executeAndKeepPlan {
+      sql(query)
+    }
+
+    val primaryScan = collect(executedPlan) {
+      case s: BatchScanExec => s
+    }.head
+    assert(primaryScan.schema.sameType(StructType.fromDDL(primaryScanSchema)))
+
+    primaryScan.runtimeFilters match {
+      case Seq(DynamicPruningExpression(child: InSubqueryExec)) =>
+        val groupFilterScan = collect(child.plan) {
+          case s: BatchScanExec => s
+        }.head
+        assert(groupFilterScan.schema.sameType(StructType.fromDDL(groupFilterScanSchema)))
+
+      case _ =>
+        fail("could not find group filter scan")
+    }
+  }
+
+  private def checkReplacedPartitions(expectedPartitions: Seq[Any]): Unit = {
+    val actualPartitions = table.replacedPartitions.map {
+      case Seq(partValue: UTF8String) => partValue.toString
+      case Seq(partValue) => partValue
+      case other => fail(s"expected only one partition value: $other")
+    }
+    assert(actualPartitions == expectedPartitions, "replaced partitions must match")
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
