This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 7eca60d4f30 [SPARK-41162][SQL][3.3] Fix anti- and semi-join for self-join with aggregations

7eca60d4f30 is described below

commit 7eca60d4f304d4a1a66add9fd04166d8eed1dd4f
Author: Enrico Minack <git...@enrico.minack.dev>
AuthorDate: Fri Jan 6 11:32:45 2023 +0800

[SPARK-41162][SQL][3.3] Fix anti- and semi-join for self-join with aggregations

### What changes were proposed in this pull request?

Backport #39131 to branch-3.3.

Rule `PushDownLeftSemiAntiJoin` should not push an anti-join below an `Aggregate` when the join condition references an attribute that exists both in its right plan and in its left plan's child. This usually happens when the anti-join / semi-join is a self-join and `DeduplicateRelations` cannot deduplicate those attributes (in this example, because `value` is projected to `id`).

This safety guard already exists for `Project` and `Union`, but `Aggregate` lacks it.

### Why are the changes needed?

Without this change, the optimizer creates an incorrect plan.

This example fails with `distinct()` (an aggregation) and succeeds without `distinct()`, although both queries are semantically equivalent:

```scala
import spark.implicits._ // for toDF and $"..."; assumes a SparkSession named spark, as in spark-shell

val ids = Seq(1, 2, 3).toDF("id").distinct()
val result = ids.withColumn("id", $"id" + 1).join(ids, Seq("id"), "left_anti").collect()
assert(result.length == 1) // only id 4 (= 3 + 1) has no match in ids
```

With `distinct()`, rule `PushDownLeftSemiAntiJoin` creates the join condition `(value#907 + 1) = value#907`, which can never be true. This effectively removes the anti-join.

**Before this PR:**

The anti-join is fully removed from the plan.
```
== Physical Plan ==
AdaptiveSparkPlan (16)
+- == Final Plan ==
   LocalTableScan (1)

(16) AdaptiveSparkPlan
Output [1]: [id#900]
Arguments: isFinalPlan=true
```

This is caused by `PushDownLeftSemiAntiJoin` adding the join condition `(value#907 + 1) = value#907`, which is wrong because `id#910` in `(id#910 + 1) AS id#912` exists both in the right child of the join and in the left grandchild:

```
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin ===
!Join LeftAnti, (id#912 = id#910)                  Aggregate [id#910], [(id#910 + 1) AS id#912]
!:- Aggregate [id#910], [(id#910 + 1) AS id#912]   +- Project [value#907 AS id#910]
!:  +- Project [value#907 AS id#910]                  +- Join LeftAnti, ((value#907 + 1) = value#907)
!:     +- LocalRelation [value#907]                      :- LocalRelation [value#907]
!+- Aggregate [id#910], [id#910]                         +- Aggregate [id#910], [id#910]
!   +- Project [value#914 AS id#910]                        +- Project [value#914 AS id#910]
!      +- LocalRelation [value#914]                            +- LocalRelation [value#914]
```

The right child of the join and the left grandchild would become the children of the pushed-down join. Since both contain `id#910`, the pushed-down join condition becomes invalid.

**After this PR:**

The join condition `id#912 = id#910` references `id#910`, which exists both in the right side of the join and in the child of the `Aggregate` that defines `(id#910 + 1) AS id#912`. Pushing the join below the `Aggregate` would make `id#910` ambiguous, so the join is not pushed down and the rule is no longer applied.
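The ambiguity check described above can be sketched without any Spark dependency. This is a hypothetical, simplified model rather than Catalyst code: `Attr`, `Plan`, and this `canPushThroughCondition` are stand-ins that track only the output attributes of each node by name and expression id. The model follows the rule stated in this commit (no attribute referenced by the join condition may be produced both by the right side and by the children the join would move into). The exact body of Spark's helper is not part of this patch, so treat the sketch as an approximation.

```scala
// Hypothetical model of the push-down guard; not Spark's Catalyst API.
object PushDownGuardSketch {
  // An attribute is identified by its name and expression id, e.g. id#910.
  final case class Attr(name: String, exprId: Int) {
    override def toString: String = s"$name#$exprId"
  }

  // A plan node reduced to the set of attributes it outputs.
  final case class Plan(label: String, output: Set[Attr])

  // True if no attribute referenced by the join condition is produced both by
  // the right side of the join and by the children the join would be pushed into.
  def canPushThroughCondition(
      children: Seq[Plan],
      conditionRefs: Option[Set[Attr]],
      right: Plan): Boolean = conditionRefs match {
    case None => true // no condition, nothing can become ambiguous
    case Some(refs) =>
      val childOutput: Set[Attr] = children.flatMap(_.output).toSet
      refs.intersect(right.output).intersect(childOutput).isEmpty
  }

  // Attributes from the self-join example above.
  val id910: Attr = Attr("id", 910)
  val id912: Attr = Attr("id", 912)

  // Child of the left Aggregate [id#910], [(id#910 + 1) AS id#912].
  val leftAggChild: Plan = Plan("Project [value#907 AS id#910]", Set(id910))
  // Right side of the join.
  val rightAgg: Plan = Plan("Aggregate [id#910], [id#910]", Set(id910))
  // Join condition (id#912 = id#910) references both attributes.
  val selfJoinCondRefs: Option[Set[Attr]] = Some(Set(id912, id910))

  def main(args: Array[String]): Unit = {
    val allowed = canPushThroughCondition(Seq(leftAggChild), selfJoinCondRefs, rightAgg)
    println(s"push-down allowed: $allowed") // prints: push-down allowed: false
  }
}
```

Because `id#910` appears in the condition, in the right side, and in the aggregate's child, the intersection is non-empty and the push-down is rejected. A join against an unrelated relation (for example one producing only `d#1`) would leave the intersection empty and allow the push-down.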
The final plan contains the anti-join:

```
== Physical Plan ==
AdaptiveSparkPlan (24)
+- == Final Plan ==
   * BroadcastHashJoin LeftSemi BuildRight (14)
   :- * HashAggregate (7)
   :  +- AQEShuffleRead (6)
   :     +- ShuffleQueryStage (5), Statistics(sizeInBytes=48.0 B, rowCount=3)
   :        +- Exchange (4)
   :           +- * HashAggregate (3)
   :              +- * Project (2)
   :                 +- * LocalTableScan (1)
   +- BroadcastQueryStage (13), Statistics(sizeInBytes=1024.0 KiB, rowCount=3)
      +- BroadcastExchange (12)
         +- * HashAggregate (11)
            +- AQEShuffleRead (10)
               +- ShuffleQueryStage (9), Statistics(sizeInBytes=48.0 B, rowCount=3)
                  +- ReusedExchange (8)

(8) ReusedExchange [Reuses operator id: 4]
Output [1]: [id#898]

(24) AdaptiveSparkPlan
Output [1]: [id#900]
Arguments: isFinalPlan=true
```

### Does this PR introduce _any_ user-facing change?

It fixes a correctness bug.

### How was this patch tested?

Unit tests in `DataFrameJoinSuite` and `LeftSemiAntiJoinPushDownSuite`.

Closes #39409 from EnricoMi/branch-antijoin-selfjoin-fix-3.3.

Authored-by: Enrico Minack <git...@enrico.minack.dev>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
(cherry picked from commit b97f79da04acc9bde1cb4def7dc33c22cfc11372)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../optimizer/PushDownLeftSemiAntiJoin.scala      | 13 ++---
 .../optimizer/LeftSemiAntiJoinPushDownSuite.scala | 57 ++++++++++++++--------
 .../org/apache/spark/sql/DataFrameJoinSuite.scala | 18 +++++++
 3 files changed, 63 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
index 31b9d604060..8a146c4d688 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
@@ -56,9 +56,10 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan]
       }

     // LeftSemi/LeftAnti over Aggregate, only push down if join can be planned as broadcast join.
-    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _)
+    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), joinCond, _)
         if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
           !agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
+          canPushThroughCondition(agg.children, joinCond, rightOp) &&
           canPlanAsBroadcastHashJoin(join, conf) =>
       val aliasMap = getAliasMap(agg)
       val canPushDownPredicate = (predicate: Expression) => {
@@ -105,11 +106,11 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan]
   }

   /**
-   * Check if we can safely push a join through a project or union by making sure that attributes
-   * referred in join condition do not contain the same attributes as the plan they are moved
-   * into. This can happen when both sides of join refers to the same source (self join). This
-   * function makes sure that the join condition refers to attributes that are not ambiguous (i.e
-   * present in both the legs of the join) or else the resultant plan will be invalid.
+   * Check if we can safely push a join through a project, aggregate, or union by making sure that
+   * attributes referred in join condition do not contain the same attributes as the plan they are
+   * moved into. This can happen when both sides of join refers to the same source (self join).
+   * This function makes sure that the join condition refers to attributes that are not ambiguous
+   * (i.e present in both the legs of the join) or else the resultant plan will be invalid.
    */
   private def canPushThroughCondition(
       plans: Seq[LogicalPlan],
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
index 88c29c9274a..0b5a7f76607 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.IntegerType

-class LeftSemiPushdownSuite extends PlanTest {
+class LeftSemiAntiJoinPushDownSuite extends PlanTest {

   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
@@ -46,7 +46,7 @@ class LeftSemiPushdownSuite extends PlanTest {
   val testRelation1 = LocalRelation('d.int)
   val testRelation2 = LocalRelation('e.int)

-  test("Project: LeftSemiAnti join pushdown") {
+  test("Project: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .select(star())
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -59,7 +59,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") {
+  test("Project: LeftSemi join no pushdown - non-deterministic proj exprs") {
     val originalQuery = testRelation
       .select(Rand(1), 'b, 'c)
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -68,7 +68,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }

-  test("Project: LeftSemiAnti join non correlated scalar subq") {
+  test("Project: LeftSemi join pushdown - non-correlated scalar subq") {
     val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
     val originalQuery = testRelation
       .select(subq.as("sum"))
@@ -83,7 +83,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") {
+  test("Project: LeftSemi join no pushdown - correlated scalar subq in projection list") {
     val testRelation2 = LocalRelation('e.int, 'f.int)
     val subqPlan = testRelation2.groupBy('e)(sum('f).as("sum")).where('e === 'a)
     val subqExpr = ScalarSubquery(subqPlan)
@@ -95,7 +95,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }

-  test("Aggregate: LeftSemiAnti join pushdown") {
+  test("Aggregate: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .groupBy('b)('b, sum('c))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -109,7 +109,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") {
+  test("Aggregate: LeftSemi join no pushdown - non-deterministic aggr expressions") {
     val originalQuery = testRelation
       .groupBy('b)('b, Rand(10).as('c))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -142,7 +142,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }

-  test("LeftSemiAnti join over aggregate - no pushdown") {
+  test("Aggregate: LeftSemi join no pushdown") {
     val originalQuery = testRelation
       .groupBy('b)('b, sum('c).as('sum))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd))
@@ -151,7 +151,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }

-  test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") {
+  test("Aggregate: LeftSemi join pushdown - non-correlated scalar subq aggr exprs") {
     val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
     val originalQuery = testRelation
       .groupBy('a) ('a, subq.as("sum"))
@@ -166,7 +166,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("LeftSemiAnti join over Window") {
+  test("Window: LeftSemi join pushdown") {
     val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))

     val originalQuery = testRelation
@@ -184,7 +184,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Window: LeftSemi partial pushdown") {
+  test("Window: LeftSemi join partial pushdown") {
     // Attributes from join condition which does not refer to the window partition spec
     // are kept up in the plan as a Filter operator above Window.
     val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
@@ -224,7 +224,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Union: LeftSemiAnti join pushdown") {
+  test("Union: LeftSemi join pushdown") {
     val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)

     val originalQuery = Union(Seq(testRelation, testRelation2))
@@ -240,7 +240,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Union: LeftSemiAnti join pushdown in self join scenario") {
+  test("Union: LeftSemi join pushdown in self join scenario") {
     val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
     val attrX = testRelation2.output.head

@@ -259,7 +259,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Unary: LeftSemiAnti join pushdown") {
+  test("Unary: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .select(star())
       .repartition(1)
@@ -274,7 +274,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Unary: LeftSemiAnti join pushdown - empty join condition") {
+  test("Unary: LeftSemi join pushdown - empty join condition") {
     val originalQuery = testRelation
       .select(star())
       .repartition(1)
@@ -289,7 +289,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Unary: LeftSemi join pushdown - partial pushdown") {
+  test("Unary: LeftSemi join partial pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -305,7 +305,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Unary: LeftAnti join pushdown - no pushdown") {
+  test("Unary: LeftAnti join no pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -315,7 +315,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }

-  test("Unary: LeftSemiAnti join pushdown - no pushdown") {
+  test("Unary: LeftSemi join - no pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -325,7 +325,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }

-  test("Unary: LeftSemi join push down through Expand") {
+  test("Unary: LeftSemi join pushdown through Expand") {
     val expand = Expand(Seq(Seq('a, 'b, "null"), Seq('a, "null", 'c)),
       Seq('a, 'b, 'c), testRelation)
     val originalQuery = expand
@@ -431,6 +431,25 @@ class LeftSemiPushdownSuite extends PlanTest {
     }
   }

+  Seq(LeftSemi, LeftAnti).foreach { case jt =>
+    test(s"Aggregate: $jt join no pushdown - join condition refers left leg and right leg child") {
+      val aggregation = testRelation
+        .select('b.as("id"), 'c)
+        .groupBy('id)('id, sum('c).as("sum"))
+
+      // reference "b" exists in left leg, and the children of the right leg of the join
+      val originalQuery = aggregation.select(('id + 1).as("id_plus_1"), 'sum)
+        .join(aggregation, joinType = jt, condition = Some('id === 'id_plus_1))
+      val optimized = Optimize.execute(originalQuery.analyze)
+      val correctAnswer = testRelation
+        .select('b.as("id"), 'c)
+        .groupBy('id)(('id + 1).as("id_plus_1"), sum('c).as("sum"))
+        .join(aggregation, joinType = jt, condition = Some('id === 'id_plus_1))
+        .analyze
+      comparePlans(optimized, correctAnswer)
+    }
+  }
+
   Seq(LeftSemi, LeftAnti).foreach { case outerJT =>
     Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT =>
       test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 1fda13f996a..4298d503b10 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -288,6 +288,24 @@ class DataFrameJoinSuite extends QueryTest
     }
   }

+  Seq("left_semi", "left_anti").foreach { joinType =>
+    test(s"SPARK-41162: $joinType self-joined aggregated dataframe") {
+      // aggregated dataframe
+      val ids = Seq(1, 2, 3).toDF("id").distinct()
+
+      // self-joined via joinType
+      val result = ids.withColumn("id", $"id" + 1)
+        .join(ids, usingColumns = Seq("id"), joinType = joinType).collect()
+
+      val expected = joinType match {
+        case "left_semi" => 2
+        case "left_anti" => 1
+        case _ => -1 // unsupported test type, test will always fail
+      }
+      assert(result.length == expected)
+    }
+  }
+
   def extractLeftDeepInnerJoins(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
     case j @ Join(left, right, _: InnerLike, _, _) => right +: extractLeftDeepInnerJoins(left)
     case Filter(_, child) => extractLeftDeepInnerJoins(child)
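The expected row counts in the `SPARK-41162` regression test in `DataFrameJoinSuite` (2 rows for `left_semi`, 1 row for `left_anti`) can be checked with ordinary Scala collections. This is a hypothetical model that replaces the DataFrames with `Set[Int]`; it is not part of the patch. It also models, under the same simplification, a left-anti join whose condition reads only the left value and is never true (like `(value#907 + 1) = value#907`). Such a join keeps every left row, which is consistent with the "Before" plan where the join disappears entirely.

```scala
// Plain Scala collections model of the SPARK-41162 test expectations; no Spark involved.
object SelfJoinExpectation {
  // Seq(1, 2, 3).toDF("id").distinct()
  val ids: Set[Int] = Seq(1, 2, 3).toSet

  // ids.withColumn("id", $"id" + 1)
  val shifted: Set[Int] = ids.map(_ + 1)

  // left_semi: keep left rows whose id has a match on the right
  val leftSemi: Set[Int] = shifted.filter(id => ids.contains(id))

  // left_anti: keep left rows whose id has no match on the right
  val leftAnti: Set[Int] = shifted.filterNot(id => ids.contains(id))

  // Incorrect rewrite: anti-join pushed below the aggregation with a condition
  // that only reads the left value and can never hold, followed by the "+ 1" projection.
  val buggyAnti: Set[Int] =
    ids.filterNot(v => ids.exists(_ => v + 1 == v)).map(_ + 1)

  def main(args: Array[String]): Unit = {
    println(s"left_semi rows: ${leftSemi.size}") // prints: left_semi rows: 2
    println(s"left_anti rows: ${leftAnti.size}") // prints: left_anti rows: 1
  }
}
```

Only `id` 2 and 3 of the shifted set `{2, 3, 4}` occur in `{1, 2, 3}`, hence 2 semi-join rows and a single anti-join row (`4`). The never-true condition, in contrast, leaves all three shifted rows in place.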