This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 103de914a5f [SPARK-44649][SQL] Runtime Filter supports passing 
equivalent creation side expressions
103de914a5f is described below

commit 103de914a5f96fccbe722663ee69c8ee7d9c8135
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Wed Oct 18 14:55:51 2023 +0800

    [SPARK-44649][SQL] Runtime Filter supports passing equivalent creation side 
expressions
    
    ### What changes were proposed in this pull request?
    Currently, the Spark runtime filter supports using a multi-level shuffle join side as the filter creation side (see https://github.com/apache/spark/pull/39170). Although that feature covers more adaptive scenarios and improves performance, other cases still need to be supported.
    
    **Optimization of Expression Transitivity on the Creation Side of Spark 
Runtime Filter**
    
    **Principle**
    Join-key equalities are transitive across some joins. For example, given
    `Tab1.col1A = Tab2.col2B` and `Tab2.col2B = Tab3.col3C`,
    it can be inferred that `Tab1.col1A = Tab3.col3C`.
    
    **Optimization points**
    Currently, the runtime filter's creation-side expression only uses directly joined keys. By exploiting the transitivity of join conditions, runtime filters can be injected in many more scenarios, such as:
    
    ```
    SELECT *
    FROM (
      SELECT *
      FROM tab1
        JOIN tab2 ON tab1.c1 = tab2.c2
      WHERE tab2.a2 = 5
    ) AS a
      JOIN tab3 ON tab3.c3 = a.c1
    ```
    
    Here `tab3.c3` is joined only with `tab1.c1`, not with `tab2.c2`. Although there is a selective filter on tab2 (`tab2.a2 = 5`), Spark is currently unable to inject a runtime filter into tab3 based on it.
    Once transitivity is taken into account, we know that `tab3.c3` and `tab2.c2` are equivalent join keys, so we can still inject a runtime filter and improve performance.
    
    With the current implementation, Spark only injects a bloom-filter runtime filter into tab1, built from tab2 filtered by `tab2.a2 = 5`.
    Because there is no direct join between tab3 and tab2, Spark cannot inject a runtime filter into tab3 using the same bloom filter.
    However, the SQL above has the join condition `tab3.c3 = a.c1` (i.e. `tab1.c1`) between tab3 and the subquery `a`, and the subquery has the join condition `tab1.c1 = tab2.c2`. Relying on the transitivity of these join conditions, we can derive the virtual join condition `tab3.c3 = tab2.c2` and then inject the bloom filter based on `tab2.a2 = 5` into tab3.
    
    ### Why are the changes needed?
    Enhance the Spark runtime filter and improve performance.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    This only changes the internal implementation.
    
    ### How was this patch tested?
    New tests.
    Micro benchmark for q75 in TPC-DS.
    **2TB TPC-DS**
    | TPC-DS Query   | Before(Seconds)  | After(Seconds)  | Speedup (Before/After - 1)  |
    |  ----  | ----  | ----  | ----  |
    | q75 | 129.664 | 81.562 | 58.98% |
    
    Closes #42317 from beliefer/SPARK-44649.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/optimizer/InjectRuntimeFilter.scala   | 64 +++++++++++++++-------
 .../spark/sql/InjectRuntimeFilterSuite.scala       | 38 +++++++++++--
 2 files changed, 75 insertions(+), 27 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index 8737082e571..30526bd8106 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -125,14 +125,14 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with 
PredicateHelper with J
    */
   private def extractSelectiveFilterOverScan(
       plan: LogicalPlan,
-      filterCreationSideKey: Expression): Option[LogicalPlan] = {
-    @tailrec
+      filterCreationSideKey: Expression): Option[(Expression, LogicalPlan)] = {
     def extract(
         p: LogicalPlan,
         predicateReference: AttributeSet,
         hasHitFilter: Boolean,
         hasHitSelectiveFilter: Boolean,
-        currentPlan: LogicalPlan): Option[LogicalPlan] = p match {
+        currentPlan: LogicalPlan,
+        targetKey: Expression): Option[(Expression, LogicalPlan)] = p match {
       case Project(projectList, child) if hasHitFilter =>
         // We need to make sure all expressions referenced by filter 
predicates are simple
         // expressions.
@@ -143,41 +143,62 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with 
PredicateHelper with J
             referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ 
++ _),
             hasHitFilter,
             hasHitSelectiveFilter,
-            currentPlan)
+            currentPlan,
+            targetKey)
         } else {
           None
         }
       case Project(_, child) =>
         assert(predicateReference.isEmpty && !hasHitSelectiveFilter)
-        extract(child, predicateReference, hasHitFilter, 
hasHitSelectiveFilter, currentPlan)
+        extract(child, predicateReference, hasHitFilter, 
hasHitSelectiveFilter, currentPlan,
+          targetKey)
       case Filter(condition, child) if isSimpleExpression(condition) =>
         extract(
           child,
           predicateReference ++ condition.references,
           hasHitFilter = true,
           hasHitSelectiveFilter = hasHitSelectiveFilter || 
isLikelySelective(condition),
-          currentPlan)
-      case ExtractEquiJoinKeys(_, _, _, _, _, left, right, _) =>
+          currentPlan,
+          targetKey)
+      case ExtractEquiJoinKeys(_, lkeys, rkeys, _, _, left, right, _) =>
         // Runtime filters use one side of the [[Join]] to build a set of join 
key values and prune
         // the other side of the [[Join]]. It's also OK to use a superset of 
the join key values
         // (ignore null values) to do the pruning.
-        if (left.output.exists(_.semanticEquals(filterCreationSideKey))) {
-          extract(left, AttributeSet.empty,
-            hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = 
left)
-        } else if 
(right.output.exists(_.semanticEquals(filterCreationSideKey))) {
-          extract(right, AttributeSet.empty,
-            hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = 
right)
+        // We assume other rules have already pushed predicates through join 
if possible.
+        // So the predicate references won't pass on anymore.
+        if (left.output.exists(_.semanticEquals(targetKey))) {
+          extract(left, AttributeSet.empty, hasHitFilter = false, 
hasHitSelectiveFilter = false,
+            currentPlan = left, targetKey = targetKey).orElse {
+            // We can also extract from the right side if the join keys are 
transitive.
+            lkeys.zip(rkeys).find(_._1.semanticEquals(targetKey)).map(_._2)
+              .flatMap { newTargetKey =>
+                extract(right, AttributeSet.empty,
+                  hasHitFilter = false, hasHitSelectiveFilter = false, 
currentPlan = right,
+                  targetKey = newTargetKey)
+              }
+          }
+        } else if (right.output.exists(_.semanticEquals(targetKey))) {
+          extract(right, AttributeSet.empty, hasHitFilter = false, 
hasHitSelectiveFilter = false,
+            currentPlan = right, targetKey = targetKey).orElse {
+            // We can also extract from the left side if the join keys are 
transitive.
+            rkeys.zip(lkeys).find(_._1.semanticEquals(targetKey)).map(_._2)
+              .flatMap { newTargetKey =>
+                extract(left, AttributeSet.empty,
+                  hasHitFilter = false, hasHitSelectiveFilter = false, 
currentPlan = left,
+                  targetKey = newTargetKey)
+              }
+          }
         } else {
           None
         }
       case _: LeafNode if hasHitSelectiveFilter =>
-        Some(currentPlan)
+        Some((targetKey, currentPlan))
       case _ => None
     }
 
     if (!plan.isStreaming) {
-      extract(plan, AttributeSet.empty,
-        hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = 
plan)
+      extract(plan, AttributeSet.empty, hasHitFilter = false, 
hasHitSelectiveFilter = false,
+        currentPlan = plan, targetKey = filterCreationSideKey)
     } else {
       None
     }
@@ -239,7 +260,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with 
PredicateHelper with J
       filterApplicationSide: LogicalPlan,
       filterCreationSide: LogicalPlan,
       filterApplicationSideKey: Expression,
-      filterCreationSideKey: Expression): Option[LogicalPlan] = {
+      filterCreationSideKey: Expression): Option[(Expression, LogicalPlan)] = {
     if (findExpressionAndTrackLineageDown(
       filterApplicationSideKey, filterApplicationSide).isDefined &&
       satisfyByteSizeRequirement(filterApplicationSide)) {
@@ -331,8 +352,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with 
PredicateHelper with J
             val hasShuffle = isProbablyShuffleJoin(left, right, hint)
             if (canPruneLeft(joinType) && (hasShuffle || 
probablyHasShuffle(left))) {
               extractBeneficialFilterCreatePlan(left, right, l, r).foreach {
-                filterCreationSidePlan =>
-                  newLeft = injectFilter(l, newLeft, r, filterCreationSidePlan)
+                case (filterCreationSideKey, filterCreationSidePlan) =>
+                  newLeft = injectFilter(l, newLeft, filterCreationSideKey, 
filterCreationSidePlan)
               }
             }
             // Did we actually inject on the left? If not, try on the right
@@ -341,8 +362,9 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with 
PredicateHelper with J
             if (newLeft.fastEquals(oldLeft) && canPruneRight(joinType) &&
               (hasShuffle || probablyHasShuffle(right))) {
               extractBeneficialFilterCreatePlan(right, left, r, l).foreach {
-                filterCreationSidePlan =>
-                  newRight = injectFilter(r, newRight, l, 
filterCreationSidePlan)
+                case (filterCreationSideKey, filterCreationSidePlan) =>
+                  newRight = injectFilter(
+                    r, newRight, filterCreationSideKey, filterCreationSidePlan)
               }
             }
             if (!newLeft.fastEquals(oldLeft) || 
!newRight.fastEquals(oldRight)) {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
index fedfd9ff587..c46e0bfcecb 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
@@ -360,7 +360,7 @@ class InjectRuntimeFilterSuite extends QueryTest with 
SQLTestUtils with SharedSp
     
withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key
 -> "3000",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
       assertRewroteSemiJoin("select * from bf1 join bf2 join bf3 join bf4 on " 
+
-        "bf1.c1 = bf2.c2 and bf2.c2 = bf3.c3 and bf3.c3 = bf4.c4 where bf1.a1 
= 5")
+        "bf1.c1 = bf2.c2 and bf2.c2 = bf3.c3 and bf3.c3 = bf4.c4 where bf1.a1 
= 5", 3)
     }
   }
 
@@ -390,34 +390,60 @@ class InjectRuntimeFilterSuite extends QueryTest with 
SQLTestUtils with SharedSp
   test("Runtime bloom filter join: two joins") {
     
withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key
 -> "3000",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      // bf2 as creation side and inject runtime filter for bf1 and bf3.
       assertRewroteWithBloomFilter("select * from bf1 join bf2 join bf3 on 
bf1.c1 = bf2.c2 " +
         "and bf3.c3 = bf2.c2 where bf2.a2 = 5", 2)
-      assertRewroteWithBloomFilter("select * from (select * from bf1 left semi 
join bf2 on " +
-        "bf1.c1 = bf2.c2 where bf1.a1 = 5) as a join bf3 on bf3.c3 = a.c1")
-      assertRewroteWithBloomFilter("select * from (select * from bf1 left anti 
join bf2 on " +
-        "bf1.c1 = bf2.c2 where bf1.a1 = 5) as a join bf3 on bf3.c3 = a.c1")
       assertRewroteWithBloomFilter("select * from bf1 left outer join bf2 join 
bf3 on " +
         "bf1.c1 = bf2.c2 and bf3.c3 = bf2.c2 where bf2.a2 = 5", 2)
       assertRewroteWithBloomFilter("select * from bf1 right outer join bf2 
join bf3 on " +
         "bf1.c1 = bf2.c2 and bf3.c3 = bf2.c2 where bf2.a2 = 5", 2)
+      // bf1 and bf2 hasn't shuffle. bf1 as creation side and inject runtime 
filter for bf3.
+      assertRewroteWithBloomFilter("select * from (select * from bf1 left semi 
join bf2 on " +
+        "bf1.c1 = bf2.c2 where bf1.a1 = 5) as a join bf3 on bf3.c3 = a.c1")
+      assertRewroteWithBloomFilter("select * from (select * from bf1 left anti 
join bf2 on " +
+        "bf1.c1 = bf2.c2 where bf1.a1 = 5) as a join bf3 on bf3.c3 = a.c1")
+      // bf1 as creation side and inject runtime filter for bf2 and bf3.
       assertRewroteWithBloomFilter("select * from bf1 join bf2 join bf3 on 
bf1.c1 = bf2.c2 " +
         "and bf3.c3 = bf1.c1 where bf1.a1 = 5", 2)
       assertRewroteWithBloomFilter("select * from bf1 left outer join bf2 join 
bf3 on " +
         "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf1.a1 = 5", 2)
       assertRewroteWithBloomFilter("select * from bf1 right outer join bf2 
join bf3 on " +
         "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf1.a1 = 5", 2)
+      // bf2 as creation side and inject runtime filter for bf1 and bf3(join 
keys are transitive).
+      assertRewroteWithBloomFilter("select * from (select * from bf1 join bf2 
on " +
+        "bf1.c1 = bf2.c2 where bf2.a2 = 5) as a join bf3 on bf3.c3 = a.c1", 2)
+      assertRewroteWithBloomFilter("select * from (select * from bf1 left join 
bf2 on " +
+        "bf1.c1 = bf2.c2 where bf2.a2 = 5) as a join bf3 on bf3.c3 = a.c1", 2)
+      assertRewroteWithBloomFilter("select * from (select * from bf1 right 
join bf2 on " +
+        "bf1.c1 = bf2.c2 where bf2.a2 = 5) as a join bf3 on bf3.c3 = a.c1", 2)
+      // Can't leverage the transitivity of join keys due to runtime filters 
already exists.
+      // bf2 as creation side and inject runtime filter for bf1.
+      assertRewroteWithBloomFilter("select * from bf1 join bf2 join bf3 on 
bf1.c1 = bf2.c2 " +
+        "and bf3.c3 = bf1.c1 where bf2.a2 = 5")
+      assertRewroteWithBloomFilter("select * from bf1 left outer join bf2 join 
bf3 on " +
+        "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5")
+      assertRewroteWithBloomFilter("select * from bf1 right outer join bf2 
join bf3 on " +
+        "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5")
     }
 
     
withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key
 -> "3000",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1200") {
+      // bf1 as creation side and inject runtime filter for bf2 and bf3.
       assertRewroteWithBloomFilter("select * from (select * from bf1 left semi 
join bf2 on " +
         "bf1.c1 = bf2.c2 where bf1.a1 = 5) as a join bf3 on bf3.c3 = a.c1", 2)
+      // left anti join unsupported. bf1 as creation side and inject runtime 
filter for bf3.
       assertRewroteWithBloomFilter("select * from (select * from bf1 left anti 
join bf2 on " +
         "bf1.c1 = bf2.c2 where bf1.a1 = 5) as a join bf3 on bf3.c3 = a.c1")
+      // bf2 as creation side and inject runtime filter for bf1 and bf3(by 
passing key).
       assertRewroteWithBloomFilter("select * from (select * from bf1 left semi 
join bf2 on " +
+        "(bf1.c1 = bf2.c2 and bf2.a2 = 5)) as a join bf3 on bf3.c3 = a.c1", 2)
+      // left anti join unsupported.
+      // bf2 as creation side and inject runtime filter for bf3(by passing 
key).
+      assertRewroteWithBloomFilter("select * from (select * from bf1 left anti 
join bf2 on " +
         "(bf1.c1 = bf2.c2 and bf2.a2 = 5)) as a join bf3 on bf3.c3 = a.c1")
+      // left anti join unsupported and hasn't selective filter.
       assertRewroteWithBloomFilter("select * from (select * from bf1 left anti 
join bf2 on " +
-        "(bf1.c1 = bf2.c2 and bf2.a2 = 5)) as a join bf3 on bf3.c3 = a.c1", 0)
+        "(bf1.c1 = bf2.c2 and bf1.a1 = 5)) as a join bf3 on bf3.c3 = a.c1", 0)
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to