This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new e5ad5e94a8c8 [SPARK-48155][SQL] AQEPropagateEmptyRelation for join should check if remain child is just BroadcastQueryStageExec e5ad5e94a8c8 is described below commit e5ad5e94a8c891210637084a69359c1364201653 Author: Angerszhuuuu <angers....@gmail.com> AuthorDate: Tue May 14 17:32:56 2024 +0800 [SPARK-48155][SQL] AQEPropagateEmptyRelation for join should check if remain child is just BroadcastQueryStageExec ### What changes were proposed in this pull request? This is a new approach to fixing [SPARK-39551](https://issues.apache.org/jira/browse/SPARK-39551). The problem occurs in AQEPropagateEmptyRelation when one side of a join is empty and the other side is a BroadcastQueryStageExec. Instead of reverting the results of all queryStagePreparationRules, this PR skips the empty-relation propagation in that case. ### Why are the changes needed? Fix bug ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manually tested `SPARK-39551: Invalid plan check - invalid broadcast query stage`: it passes with this PR even without the original fix. For the added UT: ``` test("SPARK-48155: AQEPropagateEmptyRelation check remained child for join") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { val (_, adaptivePlan) = runAdaptiveAndVerifyResult( """ |SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1 |INNER JOIN ( | SELECT * FROM testData2 | WHERE b = 0 | UNION ALL | SELECT * FROM testData2 | WHErE b != 0 |) t2 |ON t1.b = t2.b AND t1.a = 0 |RIGHT OUTER JOIN testData2 t3 |ON t1.a > t3.a |GROUP BY t3.b """.stripMargin ) assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1) assert(findTopLevelUnion(adaptivePlan).size == 0) } } ``` Before this PR, the adaptive plan is: ``` *(9) HashAggregate(keys=[b#226], functions=[count(1)], output=[b#226, count(a)#228L]) +- AQEShuffleRead coalesced +- ShuffleQueryStage 3 +- Exchange hashpartitioning(b#226, 5), ENSURE_REQUIREMENTS, [plan_id=356] +- *(8) 
HashAggregate(keys=[b#226], functions=[partial_count(1)], output=[b#226, count#232L]) +- *(8) Project [b#226] +- BroadcastNestedLoopJoin BuildRight, RightOuter, (a#23 > a#225) :- *(7) Project [a#23] : +- *(7) SortMergeJoin [b#24], [b#220], Inner : :- *(5) Sort [b#24 ASC NULLS FIRST], false, 0 : : +- AQEShuffleRead coalesced : : +- ShuffleQueryStage 0 : : +- Exchange hashpartitioning(b#24, 5), ENSURE_REQUIREMENTS, [plan_id=211] : : +- *(1) Filter (a#23 = 0) : : +- *(1) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#23, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#24] : : +- Scan[obj#22] : +- *(6) Sort [b#220 ASC NULLS FIRST], false, 0 : +- AQEShuffleRead coalesced : +- ShuffleQueryStage 1 : +- Exchange hashpartitioning(b#220, 5), ENSURE_REQUIREMENTS, [plan_id=233] : +- Union : :- *(2) Project [b#220] : : +- *(2) Filter (b#220 = 0) : : +- *(2) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#219, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#220] : : +- Scan[obj#218] : +- *(3) Project [b#223] : +- *(3) Filter NOT (b#223 = 0) : +- *(3) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#222, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#223] : +- Scan[obj#221] +- BroadcastQueryStage 2 +- BroadcastExchange IdentityBroadcastMode, [plan_id=260] +- *(4) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#225, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#226] +- Scan[obj#224] ``` After this patch ``` *(6) HashAggregate(keys=[b#226], functions=[count(1)], output=[b#226, 
count(a)#228L]) +- AQEShuffleRead coalesced +- ShuffleQueryStage 3 +- Exchange hashpartitioning(b#226, 5), ENSURE_REQUIREMENTS, [plan_id=319] +- *(5) HashAggregate(keys=[b#226], functions=[partial_count(1)], output=[b#226, count#232L]) +- *(5) Project [b#226] +- BroadcastNestedLoopJoin BuildRight, RightOuter, (a#23 > a#225) :- LocalTableScan <empty>, [a#23] +- BroadcastQueryStage 2 +- BroadcastExchange IdentityBroadcastMode, [plan_id=260] +- *(4) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).a AS a#225, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData2, true])).b AS b#226] +- Scan[obj#224] [info] - xxxx (3 seconds, 136 milliseconds) ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #46523 from AngersZhuuuu/SPARK-48155. Authored-by: Angerszhuuuu <angers....@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../optimizer/PropagateEmptyRelation.scala | 13 ++++---- .../adaptive/AQEPropagateEmptyRelation.scala | 7 +++++ .../adaptive/AdaptiveQueryExecSuite.scala | 35 ++++++++++++++++++++++ 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index fd7a87087ddd..296274c61c18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -65,6 +65,8 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] = plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) } + protected def canExecuteWithoutJoin(plan: LogicalPlan): Boolean = true + protected 
def commonApplyFunc: PartialFunction[LogicalPlan, LogicalPlan] = { case p: Union if p.children.exists(isEmpty) => val newChildren = p.children.filterNot(isEmpty) @@ -111,18 +113,19 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup case LeftSemi if isRightEmpty | isFalseCondition => empty(p) case LeftAnti if isRightEmpty | isFalseCondition => p.left case FullOuter if isLeftEmpty && isRightEmpty => empty(p) - case LeftOuter | FullOuter if isRightEmpty => + case LeftOuter | FullOuter if isRightEmpty && canExecuteWithoutJoin(p.left) => Project(p.left.output ++ nullValueProjectList(p.right), p.left) case RightOuter if isRightEmpty => empty(p) - case RightOuter | FullOuter if isLeftEmpty => + case RightOuter | FullOuter if isLeftEmpty && canExecuteWithoutJoin(p.right) => Project(nullValueProjectList(p.left) ++ p.right.output, p.right) - case LeftOuter if isFalseCondition => + case LeftOuter if isFalseCondition && canExecuteWithoutJoin(p.left) => Project(p.left.output ++ nullValueProjectList(p.right), p.left) - case RightOuter if isFalseCondition => + case RightOuter if isFalseCondition && canExecuteWithoutJoin(p.right) => Project(nullValueProjectList(p.left) ++ p.right.output, p.right) case _ => p } - } else if (joinType == LeftSemi && conditionOpt.isEmpty && nonEmpty(p.right)) { + } else if (joinType == LeftSemi && conditionOpt.isEmpty && + nonEmpty(p.right) && canExecuteWithoutJoin(p.left)) { p.left } else if (joinType == LeftAnti && conditionOpt.isEmpty && nonEmpty(p.right)) { empty(p) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala index 7951a6f36b9b..858130fae32b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala @@ 
-82,6 +82,13 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase { case _ => false } + // A broadcast query stage can't be executed without the join operator. + // TODO: we can return the original query plan before broadcast. + override protected def canExecuteWithoutJoin(plan: LogicalPlan): Boolean = plan match { + case LogicalQueryStage(_, _: BroadcastQueryStageExec) => false + case _ => true + } + override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning( // LOCAL_RELATION and TRUE_OR_FALSE_LITERAL pattern are matched at // `PropagateEmptyRelationBase.commonApplyFunc` diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index e7b375e55f17..a7efd0aa75eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -165,6 +165,12 @@ class AdaptiveQueryExecSuite } } + private def findTopLevelUnion(plan: SparkPlan): Seq[UnionExec] = { + collect(plan) { + case l: UnionExec => l + } + } + private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = { collectWithSubqueries(plan) { case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _) => e @@ -2795,6 +2801,35 @@ class AdaptiveQueryExecSuite } } + test("SPARK-48155: AQEPropagateEmptyRelation check remained child for join") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + // Before SPARK-48155, since the AQE will call ValidateSparkPlan, + // all AQE optimize rule won't work and return the origin plan. + // After SPARK-48155, Spark avoid invalid propagate of empty relation. + // Then the UNION first child empty relation can be propagate correctly + // and the JOIN won't be propagated since will generated a invalid plan. 
+ val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1 + |INNER JOIN ( + | SELECT * FROM testData2 + | WHERE b = 0 + | UNION ALL + | SELECT * FROM testData2 + | WHErE b != 0 + |) t2 + |ON t1.b = t2.b AND t1.a = 0 + |RIGHT OUTER JOIN testData2 t3 + |ON t1.a > t3.a + |GROUP BY t3.b + """.stripMargin + ) + assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1) + assert(findTopLevelUnion(adaptivePlan).size == 0) + } + } + test("SPARK-39915: Dataset.repartition(N) may not create N partitions") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "6") { // partitioning: HashPartitioning --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org