gengliangwang commented on code in PR #55887:
URL: https://github.com/apache/spark/pull/55887#discussion_r3246275170


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala:
##########
@@ -187,6 +193,90 @@ object PushDownUtils extends Logging {
     }
   }
 
+  /**
+   * Pushes runtime filters into `scan` and re-plans its input partitions. For scans whose
+   * `outputPartitioning` is a [[KeyedPartitioning]] (SPJ-active), validates that the data source
+   * preserved the original partitioning and pads with `None` to preserve key alignment with the
+   * pre-filter partition set.
+   *
+   * Must be called at execute time: runtime filters carry [[DynamicPruningExpression]] and
+   * scalar-subquery references whose values are only resolved after their broadcast/subquery
+   * side completes. Callers should wrap the result in a `lazy val` so the mutating
+   * [[pushRuntimeFilters]] call runs at most once per scan instance.
+   *
+   * @param scan                      the V2 scan to push filters into
+   * @param runtimeFilters            runtime filters to translate and push
+   * @param partitionPredicateSchema  by-name schema for iterative [[PartitionPredicate]] pushdown
+   * @param output                    scan output attributes
+   * @param outputPartitioning        Spark-side output partitioning (used for SPJ validation)
+   * @param inputPartitions           by-name original (unfiltered) partitions; consulted only when
+   *                                  no runtime filters fire, so callers can compute it lazily
+   * @return one entry per original input partition: `Some(part)` for surviving partitions and
+   *         `None` for partition keys whose splits were entirely pruned (SPJ alignment)
+   */
+  def filterAndPlanPartitions(

Review Comment:
   Could you link the follow-up PR/JIRA for the alternative scan operator that 
motivates this extraction? Without seeing the second caller it's hard to 
validate the parameter shape (by-name `partitionPredicateSchema`, by-name 
`inputPartitions`, and the new transforms-based `getPartitionPredicateSchema` 
overload) — they look defensive for `BatchScanExec` but are presumably 
load-bearing for the upcoming caller.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala:
##########
@@ -139,12 +141,16 @@ object PushDownUtils extends Logging {
    * the first pass are used to derive PartitionPredicates in the second pass, avoiding duplicate
    * pushdown.
    *
+   * The partition-predicate schema is passed by-name so callers that cannot supply one (no
+   * partition transforms available) or whose scan does not opt into iterative pushdown pay no
+   * derivation cost.
+   *
    * @return true if any filters were pushed to the data source
    */
   def pushRuntimeFilters(
       scan: Scan,
       runtimeFilters: Seq[Expression],
-      table: Table,
+      partitionPredicateSchema: => Option[Seq[PartitionPredicateField]],

Review Comment:
   Previously this method took `table: Table` and computed the schema 
internally; now every caller has to call `getPartitionPredicateSchema(table, 
output)` themselves and pass the result. For callers that have a `Table` (the 
only existing caller today), that's two calls where there used to be one. 
Consider keeping a thin `Table`-accepting overload as the ergonomic default, 
with this `partitionPredicateSchema`-accepting form for callers that don't have 
a `Table`.
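
   A minimal sketch of the suggested shape, using simplified stand-in types rather than the real Spark classes. `Table`, `Scan`, `PartitionPredicateField`, the `String`-typed filters and output, the body of the schema-accepting method, and the `derivations` counter are all assumptions made for illustration. The point is only that the `Table` overload can delegate in one line while the by-name argument keeps the derivation lazy:

```scala
object PushDownOverloadSketch {
  // Hypothetical stand-ins for the Spark types, reduced to what the sketch needs.
  final case class Table(partitionColumns: Seq[String])
  final case class PartitionPredicateField(name: String)
  trait Scan { def supportsIterativePushdown: Boolean }

  // Counts schema derivations; exists only to make the laziness observable.
  var derivations = 0

  // Stand-in for getPartitionPredicateSchema(table, output).
  def getPartitionPredicateSchema(
      table: Table,
      output: Seq[String]): Option[Seq[PartitionPredicateField]] = {
    derivations += 1
    val fields = table.partitionColumns
      .filter(c => output.contains(c))
      .map(c => PartitionPredicateField(c))
    if (fields.isEmpty) None else Some(fields)
  }

  // Schema-accepting form, for callers that have no Table.
  // The toy body evaluates the by-name schema only when the scan opts in.
  def pushRuntimeFilters(
      scan: Scan,
      runtimeFilters: Seq[String],
      partitionPredicateSchema: => Option[Seq[PartitionPredicateField]],
      output: Seq[String]): Boolean = {
    runtimeFilters.nonEmpty &&
      scan.supportsIterativePushdown &&
      partitionPredicateSchema.isDefined
  }

  // Thin Table-accepting overload: the ergonomic default for callers with a Table.
  // The derivation is passed by-name, so it still runs only if the callee needs it.
  def pushRuntimeFilters(
      scan: Scan,
      runtimeFilters: Seq[String],
      table: Table,
      output: Seq[String]): Boolean =
    pushRuntimeFilters(scan, runtimeFilters, getPartitionPredicateSchema(table, output), output)
}
```

   With this shape the existing call site keeps its single call, and a scan that does not opt into iterative pushdown still pays no derivation cost.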



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala:
##########
@@ -187,6 +193,90 @@ object PushDownUtils extends Logging {
     }
   }
 
+  /**
+   * Pushes runtime filters into `scan` and re-plans its input partitions. For scans whose
+   * `outputPartitioning` is a [[KeyedPartitioning]] (SPJ-active), validates that the data source
+   * preserved the original partitioning and pads with `None` to preserve key alignment with the
+   * pre-filter partition set.
+   *
+   * Must be called at execute time: runtime filters carry [[DynamicPruningExpression]] and
+   * scalar-subquery references whose values are only resolved after their broadcast/subquery
+   * side completes. Callers should wrap the result in a `lazy val` so the mutating
+   * [[pushRuntimeFilters]] call runs at most once per scan instance.
+   *
+   * @param scan                      the V2 scan to push filters into
+   * @param runtimeFilters            runtime filters to translate and push
+   * @param partitionPredicateSchema  by-name schema for iterative [[PartitionPredicate]] pushdown
+   * @param output                    scan output attributes
+   * @param outputPartitioning        Spark-side output partitioning (used for SPJ validation)
+   * @param inputPartitions           by-name original (unfiltered) partitions; consulted only when
+   *                                  no runtime filters fire, so callers can compute it lazily
+   * @return one entry per original input partition: `Some(part)` for surviving partitions and
+   *         `None` for partition keys whose splits were entirely pruned (SPJ alignment)
+   */
+  def filterAndPlanPartitions(
+      scan: Scan,
+      runtimeFilters: Seq[Expression],
+      partitionPredicateSchema: => Option[Seq[PartitionPredicateField]],
+      output: Seq[AttributeReference],
+      outputPartitioning: Partitioning,
+      inputPartitions: => Seq[InputPartition]): Seq[Option[InputPartition]] = {
+    val filtered = pushRuntimeFilters(scan, runtimeFilters, partitionPredicateSchema, output)
+    if (filtered) {
+      // call toBatch again to get filtered partitions
+      val newPartitions = scan.toBatch.planInputPartitions()
+
+      outputPartitioning match {
+        case k: KeyedPartitioning =>
+          if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
+            throw new SparkException("Data source must have preserved the original partitioning " +
+                "during runtime filtering: not all partitions implement HasPartitionKey after " +
+                "filtering")
+          }
+
+          val inputMap = k.partitionKeys.groupBy(identity).view.mapValues(_.size)
+          val comparableKeyWrapperFactory = InternalRowComparableWrapper
+            .getInternalRowComparableWrapperFactory(k.expressionDataTypes)
+          val filteredMap = newPartitions.groupBy(
+            p => comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey())
+          )
+
+          if (!filteredMap.keySet.subsetOf(inputMap.keySet)) {
+            throw new SparkException("During runtime filtering, data source must not report new " +
+                "partition keys that are not present in the original partitioning.")
+          }
+
+          inputMap.toSeq
+            .sortBy(_._1)(k.keyOrdering)
+            .flatMap { case (key, size) =>
+              // We require the new number of partitions to be equal or less than the old number of
+              // partitions for a given key. In the case of less than, empty partitions are added.
+              val fps = filteredMap.getOrElse(key, Array.empty)
+
+              if (fps.size > size) {
+                throw new SparkException("During runtime filtering, data source must not report " +
+                  s"new partitions for a given key. Before: $size partitions. " +
+                  s"After: ${fps.size} partitions")
+              }
+
+              fps.map(Some).padTo(size, None)
+            }
+
+        case _ =>
+          // no validation is needed as the data source did not report any specific partitioning
+          newPartitions.toSeq.map(Some)
+      }
+
+    } else {
+      (outputPartitioning match {
+        case k: KeyedPartitioning =>
+          inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering)

Review Comment:
   When `outputPartitioning` is `KeyedPartitioning`, this branch 
unconditionally casts each input partition to `HasPartitionKey`. For 
`BatchScanExec` this is safe because 
`DataSourceV2ScanExecBase.outputPartitioning` only produces `KeyedPartitioning` 
when every input partition already implements `HasPartitionKey`. The helper 
itself doesn't document or enforce this invariant, though — a future caller 
pairing a `KeyedPartitioning` with non-`HasPartitionKey` partitions would hit a 
cryptic `ClassCastException`. Worth a sentence in the Scaladoc, or an explicit 
precondition check.
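
   One possible form of the explicit check, sketched with stand-in types so it compiles on its own. `InputPartition`, `HasPartitionKey`, the `Int` key, `Keyed`, `Unkeyed`, `requireKeyed`, and `sortByKey` are hypothetical names, and `IllegalStateException` stands in for whatever exception type the helper would actually use. The idea is to narrow the partitions with a pattern match before sorting, so a mismatched caller gets a descriptive error instead of a `ClassCastException`:

```scala
object KeyedPreconditionSketch {
  // Hypothetical stand-ins for the connector interfaces; the key is an Int for brevity.
  trait InputPartition
  trait HasPartitionKey extends InputPartition { def partitionKey: Int }
  final case class Keyed(partitionKey: Int) extends HasPartitionKey
  case object Unkeyed extends InputPartition

  // Narrows the partitions to HasPartitionKey, failing with a descriptive message
  // on the first partition that does not implement it.
  def requireKeyed(partitions: Seq[InputPartition]): Seq[HasPartitionKey] =
    partitions.map {
      case p: HasPartitionKey => p
      case other =>
        throw new IllegalStateException(
          "KeyedPartitioning requires every input partition to implement HasPartitionKey, " +
            s"but found ${other.getClass.getName}")
    }

  // Stand-in for the no-filter branch: validate first, then sort by partition key.
  def sortByKey(partitions: Seq[InputPartition]): Seq[InputPartition] =
    requireKeyed(partitions).sortBy(_.partitionKey)
}
```

   A Scaladoc sentence stating the invariant would be the lighter alternative; the check above costs one extra pass over the partitions.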



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

