This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1e8048a8810f [SPARK-54865][CONNECT][SQL] Add foreachWithSubqueriesAndPruning method to QueryPlan
1e8048a8810f is described below

commit 1e8048a8810f81afdb237116d4fa4be369f65398
Author: Yihong He <[email protected]>
AuthorDate: Thu Jan 8 03:09:47 2026 +0800

    [SPARK-54865][CONNECT][SQL] Add foreachWithSubqueriesAndPruning method to QueryPlan
    
    ### What changes were proposed in this pull request?
    
    This PR introduces a new method, `foreachWithSubqueriesAndPruning`, in QueryPlan.scala that provides a pruning-enabled variant of `foreachWithSubqueries`: it traverses only the nodes that match a given condition, improving traversal efficiency. The PR also updates two existing usages (see the sketch after this list):
    1. SparkConnectPlanner - Changed from `transformUpWithSubqueriesAndPruning` to `foreachWithSubqueriesAndPruning`, since the code was only collecting observations without transforming the plan
    2. ObservationManager - Changed from `foreach` to `foreachWithSubqueriesAndPruning` with a condition that visits only nodes containing the COLLECT_METRICS pattern
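    
    A minimal sketch of the new call shape, shown against the CollectMetrics case this PR touches (`collectObservations` and the println body are illustrative assumptions, not code from this PR):
    
    ```scala
    import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, LogicalPlan}
    import org.apache.spark.sql.catalyst.trees.TreePattern
    
    // Hypothetical helper: traverse the plan and its (nested) subqueries,
    // descending only into subtrees whose TreePatternBits contain
    // COLLECT_METRICS; all other subtrees are skipped outright.
    def collectObservations(plan: LogicalPlan): Unit = {
      plan.foreachWithSubqueriesAndPruning(
          _.containsPattern(TreePattern.COLLECT_METRICS)) {
        case c: CollectMetrics =>
          println(s"observed metrics node: ${c.name} (df=${c.dataframeId})")
        case _ => // a matching subtree may still contain non-matching nodes
      }
    }
    ```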
    
    ### Why are the changes needed?
    
    The changes are needed to:
    1. Provide a more efficient way to traverse query plans when only nodes matching certain tree patterns need to be visited. A node's TreePatternBits are a superset of the patterns of every node in its subtree, so a subtree whose bits fail the condition cannot contain a matching node and can be skipped outright (see the comparison sketch after this list)
    2. Optimize observation management in ObservationManager by traversing only nodes that contain the COLLECT_METRICS pattern instead of visiting every node
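    
    To make the efficiency claim in point 1 concrete, a hypothetical comparison (`visitCounts` is an illustrative helper, not part of this PR) of how many nodes each traversal touches:
    
    ```scala
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.trees.TreePattern
    
    // Count the nodes visited with and without pruning. Pruning only ever
    // skips subtrees whose pattern bits fail the condition, so pruned <= unpruned.
    def visitCounts(plan: LogicalPlan): (Int, Int) = {
      var unpruned = 0
      var pruned = 0
      plan.foreachWithSubqueries(_ => unpruned += 1)
      plan.foreachWithSubqueriesAndPruning(
        _.containsPattern(TreePattern.COLLECT_METRICS))(_ => pruned += 1)
      (unpruned, pruned)
    }
    ```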
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. This is an internal optimization that improves performance and code correctness without changing any user-facing behavior or APIs.
    
    ### How was this patch tested?
    
    `build/sbt "catalyst/testOnly org.apache.spark.sql.catalyst.plans.QueryPlanSuite"`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Cursor 2.2.44
    
    Closes #53637 from heyihong/SPARK-54865.
    
    Authored-by: Yihong He <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/plans/QueryPlan.scala       | 14 +++++++++++
 .../spark/sql/catalyst/plans/QueryPlanSuite.scala  | 27 +++++++++++++++++++++-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  7 +++---
 .../spark/sql/classic/ObservationManager.scala     |  4 +++-
 4 files changed, 47 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index ddb6e3349d80..6a085d714ddf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -628,6 +628,20 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
     children.foreach(_.foreachWithSubqueries(f))
   }
 
+  /**
+   * A variant of [[foreachWithSubqueries]] with pruning support.
+   * Only traverses nodes that match the given condition.
+   */
+  def foreachWithSubqueriesAndPruning(
+      cond: TreePatternBits => Boolean)(f: PlanType => Unit): Unit = {
+    if (!cond.apply(this)) {
+      return
+    }
+    f(this)
+    subqueries.foreach(_.foreachWithSubqueriesAndPruning(cond)(f))
+    children.foreach(_.foreachWithSubqueriesAndPruning(cond)(f))
+  }
+
   /**
   * A variant of `collect`. This method not only apply the given function to all elements in this
    * plan, also considering all the plans in its (nested) subqueries.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala
index 03ed466e2b03..91f990be7bb2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -25,7 +27,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, ListQuery, Literal, NamedExpression, Rand}
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern}
 import org.apache.spark.sql.types.IntegerType
 
 class QueryPlanSuite extends SparkFunSuite {
@@ -160,4 +162,27 @@ class QueryPlanSuite extends SparkFunSuite {
     val planAfterTestRule = testRule(plan)
     assert(planAfterTestRule.output(0).nullable)
   }
+
+  test("SPARK-54865: pruning works correctly in foreachWithSubqueriesAndPruning") {
+    val a: NamedExpression = AttributeReference("a", IntegerType)()
+    val plan = Project(
+      Seq(a),
+      Filter(
+        ListQuery(Project(
+          Seq(a),
+          UnresolvedRelation(TableIdentifier("t", None))
+        )),
+        UnresolvedRelation(TableIdentifier("t", None))
+      )
+    )
+
+    val visited = ArrayBuffer[LogicalPlan]()
+    plan.foreachWithSubqueriesAndPruning(_.containsPattern(TreePattern.FILTER)) { p =>
+      visited += p
+    }
+
+    // Only 2 nodes contain FILTER pattern: outer Project and Filter
+    assert(visited.size == 2)
+    assert(visited.forall(_.containsPattern(TreePattern.FILTER)))
+  }
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 0d95fe31e063..b2c32df4d863 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -242,7 +242,7 @@ class SparkConnectPlanner(
     }
 
     if (executeHolderOpt.isDefined) {
-      plan.transformUpWithSubqueriesAndPruning(_.containsPattern(TreePattern.COLLECT_METRICS)) {
+      plan.foreachWithSubqueriesAndPruning(_.containsPattern(TreePattern.COLLECT_METRICS)) {
         case collectMetrics: CollectMetrics if !collectMetrics.child.isStreaming =>
           // TODO this might be too complex for no good reason. It might
           //  be easier to inspect the plan after it completes.
@@ -250,9 +250,10 @@ class SparkConnectPlanner(
             collectMetrics.name,
             collectMetrics.dataframeId)
           executeHolder.addObservation(collectMetrics.name, observation)
-          collectMetrics
+        case _ =>
       }
-    } else plan
+    }
+    plan
   }
 
   private def transformRelationPlugin(extension: ProtoAny): LogicalPlan = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala
index 308651b449fd..b5ec18d5ff12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala
@@ -20,6 +20,7 @@ import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.sql.{Observation, Row}
 import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics
+import org.apache.spark.sql.catalyst.trees.TreePattern
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.QueryExecutionListener
 
@@ -54,7 +55,8 @@ private[sql] class ObservationManager(session: SparkSession) {
 
   private def tryComplete(qe: QueryExecution): Unit = {
     val allMetrics = qe.observedMetrics
-    qe.logical.foreach {
+    qe.logical.foreachWithSubqueriesAndPruning(
+      _.containsPattern(TreePattern.COLLECT_METRICS)) {
       case c: CollectMetrics =>
         val keyExists = observations.containsKey((c.name, c.dataframeId))
         val metrics = allMetrics.get(c.name)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
