This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e5ad5e94a8c8 [SPARK-48155][SQL] AQEPropagateEmptyRelation for join 
should check if remain child is just BroadcastQueryStageExec
e5ad5e94a8c8 is described below

commit e5ad5e94a8c891210637084a69359c1364201653
Author: Angerszhuuuu <angers....@gmail.com>
AuthorDate: Tue May 14 17:32:56 2024 +0800

    [SPARK-48155][SQL] AQEPropagateEmptyRelation for join should check if 
the remaining child is just BroadcastQueryStageExec
    
    ### What changes were proposed in this pull request?
    This is a new approach to fixing 
[SPARK-39551](https://issues.apache.org/jira/browse/SPARK-39551).
    The problem occurs in AQEPropagateEmptyRelation when one side of a join is 
empty and the other side is a BroadcastQueryStageExec.
    Instead of reverting the results of all queryStagePreparationRules, this PR 
skips the empty-relation propagation in that case.
    
    ### Why are the changes needed?
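    To fix a bug: propagating an empty relation through such a join leaves a broadcast query stage without the join that consumes it, which produces an invalid plan.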
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Manually tested `SPARK-39551: Invalid plan check - invalid broadcast query 
stage`: it passes with this PR applied and the original SPARK-39551 fix removed.
    
    For the added unit test:
    ```
      test("SPARK-48155: AQEPropagateEmptyRelation check remained child for 
join") {
        withSQLConf(
          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
          val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
            """
              |SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1
              |INNER JOIN (
              |  SELECT * FROM testData2
              |  WHERE b = 0
              |  UNION ALL
              |  SELECT * FROM testData2
              |  WHErE b != 0
              |) t2
              |ON t1.b = t2.b AND t1.a = 0
              |RIGHT OUTER JOIN testData2 t3
              |ON t1.a > t3.a
              |GROUP BY t3.b
            """.stripMargin
          )
          assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1)
          assert(findTopLevelUnion(adaptivePlan).size == 0)
        }
      }
    ```
    
    Before this PR, the adaptive plan is:
    ```
    *(9) HashAggregate(keys=[b#226], functions=[count(1)], output=[b#226, 
count(a)#228L])
    +- AQEShuffleRead coalesced
       +- ShuffleQueryStage 3
          +- Exchange hashpartitioning(b#226, 5), ENSURE_REQUIREMENTS, 
[plan_id=356]
             +- *(8) HashAggregate(keys=[b#226], functions=[partial_count(1)], 
output=[b#226, count#232L])
                +- *(8) Project [b#226]
                   +- BroadcastNestedLoopJoin BuildRight, RightOuter, (a#23 > 
a#225)
                      :- *(7) Project [a#23]
                      :  +- *(7) SortMergeJoin [b#24], [b#220], Inner
                      :     :- *(5) Sort [b#24 ASC NULLS FIRST], false, 0
                      :     :  +- AQEShuffleRead coalesced
                      :     :     +- ShuffleQueryStage 0
                      :     :        +- Exchange hashpartitioning(b#24, 5), 
ENSURE_REQUIREMENTS, [plan_id=211]
                      :     :           +- *(1) Filter (a#23 = 0)
                      :     :              +- *(1) SerializeFromObject 
[knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#23, 
knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#24]
                      :     :                 +- Scan[obj#22]
                      :     +- *(6) Sort [b#220 ASC NULLS FIRST], false, 0
                      :        +- AQEShuffleRead coalesced
                      :           +- ShuffleQueryStage 1
                      :              +- Exchange hashpartitioning(b#220, 5), 
ENSURE_REQUIREMENTS, [plan_id=233]
                      :                 +- Union
                      :                    :- *(2) Project [b#220]
                      :                    :  +- *(2) Filter (b#220 = 0)
                      :                    :     +- *(2) SerializeFromObject 
[knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#219, 
knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#220]
                      :                    :        +- Scan[obj#218]
                      :                    +- *(3) Project [b#223]
                      :                       +- *(3) Filter NOT (b#223 = 0)
                      :                          +- *(3) SerializeFromObject 
[knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#222, 
knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#223]
                      :                             +- Scan[obj#221]
                      +- BroadcastQueryStage 2
                         +- BroadcastExchange IdentityBroadcastMode, 
[plan_id=260]
                            +- *(4) SerializeFromObject 
[knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#225, 
knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#226]
                               +- Scan[obj#224]
    
    ```
    
    After this patch, the adaptive plan is:
    ```
    *(6) HashAggregate(keys=[b#226], functions=[count(1)], output=[b#226, 
count(a)#228L])
    +- AQEShuffleRead coalesced
       +- ShuffleQueryStage 3
          +- Exchange hashpartitioning(b#226, 5), ENSURE_REQUIREMENTS, 
[plan_id=319]
             +- *(5) HashAggregate(keys=[b#226], functions=[partial_count(1)], 
output=[b#226, count#232L])
                +- *(5) Project [b#226]
                   +- BroadcastNestedLoopJoin BuildRight, RightOuter, (a#23 > 
a#225)
                      :- LocalTableScan <empty>, [a#23]
                      +- BroadcastQueryStage 2
                         +- BroadcastExchange IdentityBroadcastMode, 
[plan_id=260]
                            +- *(4) SerializeFromObject 
[knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#225, 
knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#226]
                               +- Scan[obj#224]
    [info] - xxxx (3 seconds, 136 milliseconds)
    
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46523 from AngersZhuuuu/SPARK-48155.
    
    Authored-by: Angerszhuuuu <angers....@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../optimizer/PropagateEmptyRelation.scala         | 13 ++++----
 .../adaptive/AQEPropagateEmptyRelation.scala       |  7 +++++
 .../adaptive/AdaptiveQueryExecSuite.scala          | 35 ++++++++++++++++++++++
 3 files changed, 50 insertions(+), 5 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index fd7a87087ddd..296274c61c18 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -65,6 +65,8 @@ abstract class PropagateEmptyRelationBase extends 
Rule[LogicalPlan] with CastSup
   private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] =
     plan.output.map{ a => Alias(cast(Literal(null), a.dataType), 
a.name)(a.exprId) }
 
+  protected def canExecuteWithoutJoin(plan: LogicalPlan): Boolean = true
+
   protected def commonApplyFunc: PartialFunction[LogicalPlan, LogicalPlan] = {
     case p: Union if p.children.exists(isEmpty) =>
       val newChildren = p.children.filterNot(isEmpty)
@@ -111,18 +113,19 @@ abstract class PropagateEmptyRelationBase extends 
Rule[LogicalPlan] with CastSup
           case LeftSemi if isRightEmpty | isFalseCondition => empty(p)
           case LeftAnti if isRightEmpty | isFalseCondition => p.left
           case FullOuter if isLeftEmpty && isRightEmpty => empty(p)
-          case LeftOuter | FullOuter if isRightEmpty =>
+          case LeftOuter | FullOuter if isRightEmpty && 
canExecuteWithoutJoin(p.left) =>
             Project(p.left.output ++ nullValueProjectList(p.right), p.left)
           case RightOuter if isRightEmpty => empty(p)
-          case RightOuter | FullOuter if isLeftEmpty =>
+          case RightOuter | FullOuter if isLeftEmpty && 
canExecuteWithoutJoin(p.right) =>
             Project(nullValueProjectList(p.left) ++ p.right.output, p.right)
-          case LeftOuter if isFalseCondition =>
+          case LeftOuter if isFalseCondition && canExecuteWithoutJoin(p.left) 
=>
             Project(p.left.output ++ nullValueProjectList(p.right), p.left)
-          case RightOuter if isFalseCondition =>
+          case RightOuter if isFalseCondition && 
canExecuteWithoutJoin(p.right) =>
             Project(nullValueProjectList(p.left) ++ p.right.output, p.right)
           case _ => p
         }
-      } else if (joinType == LeftSemi && conditionOpt.isEmpty && 
nonEmpty(p.right)) {
+      } else if (joinType == LeftSemi && conditionOpt.isEmpty &&
+        nonEmpty(p.right) && canExecuteWithoutJoin(p.left)) {
         p.left
       } else if (joinType == LeftAnti && conditionOpt.isEmpty && 
nonEmpty(p.right)) {
         empty(p)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
index 7951a6f36b9b..858130fae32b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
@@ -82,6 +82,13 @@ object AQEPropagateEmptyRelation extends 
PropagateEmptyRelationBase {
     case _ => false
   }
 
+  // A broadcast query stage can't be executed without the join that consumes it, so the
+  // join must be kept. TODO: we could return the original query plan before broadcast.
+  override protected def canExecuteWithoutJoin(plan: LogicalPlan): Boolean = 
plan match {
+    case LogicalQueryStage(_, _: BroadcastQueryStageExec) => false
+    case _ => true
+  }
+
   override protected def applyInternal(p: LogicalPlan): LogicalPlan = 
p.transformUpWithPruning(
     // LOCAL_RELATION and TRUE_OR_FALSE_LITERAL pattern are matched at
     // `PropagateEmptyRelationBase.commonApplyFunc`
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index e7b375e55f17..a7efd0aa75eb 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -165,6 +165,12 @@ class AdaptiveQueryExecSuite
     }
   }
 
+  private def findTopLevelUnion(plan: SparkPlan): Seq[UnionExec] = {
+    collect(plan) {
+      case l: UnionExec => l
+    }
+  }
+
   private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = {
     collectWithSubqueries(plan) {
       case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _) => e
@@ -2795,6 +2801,35 @@ class AdaptiveQueryExecSuite
     }
   }
 
+  test("SPARK-48155: AQEPropagateEmptyRelation check remained child for join") 
{
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      // Before SPARK-48155, ValidateSparkPlan rejected the re-optimized plan, so none
+      // of the AQE optimizer rules took effect, and the original plan was kept.
+      // After SPARK-48155, Spark skips only the empty-relation propagation that is invalid:
+      // the empty first child of the UNION is still propagated, but the RIGHT OUTER JOIN
+      // is kept, because removing it would leave a bare broadcast query stage.
+      val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
+        """
+          |SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1
+          |INNER JOIN (
+          |  SELECT * FROM testData2
+          |  WHERE b = 0
+          |  UNION ALL
+          |  SELECT * FROM testData2
+          |  WHErE b != 0
+          |) t2
+          |ON t1.b = t2.b AND t1.a = 0
+          |RIGHT OUTER JOIN testData2 t3
+          |ON t1.a > t3.a
+          |GROUP BY t3.b
+        """.stripMargin
+      )
+      assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1)
+      assert(findTopLevelUnion(adaptivePlan).size == 0)
+    }
+  }
+
   test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
     withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "6") {
       // partitioning:  HashPartitioning


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to