This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b8b5acd [SPARK-19712][SQL][FOLLOW-UP] Don't do partial pushdown when pushing down LeftAnti joins below Aggregate or Window operators. b8b5acd is described below commit b8b5acdd417f28c60c784159253a8974fa738904 Author: Dilip Biswal <dbis...@us.ibm.com> AuthorDate: Wed Apr 3 09:56:27 2019 +0800 [SPARK-19712][SQL][FOLLOW-UP] Don't do partial pushdown when pushing down LeftAnti joins below Aggregate or Window operators. ## What changes were proposed in this pull request? After [23750](https://github.com/apache/spark/pull/23750), we may pushdown left anti joins below aggregate and window operators with a partial join condition. This is not correct and was pointed out by hvanhovell and cloud-fan [here](https://github.com/apache/spark/pull/23750#discussion_r270017097). This pr addresses their comments. ## How was this patch tested? Added two new tests to verify the behaviour. Closes #24253 from dilipbiswal/SPARK-19712-followup. 
Authored-by: Dilip Biswal <dbis...@us.ibm.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../optimizer/PushDownLeftSemiAntiJoin.scala | 35 +++++++++++++++-- .../optimizer/LeftSemiAntiJoinPushDownSuite.scala | 44 ++++++++++++++++++++-- 2 files changed, 73 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala index bc868df..afe2cfa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala @@ -82,7 +82,18 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { val newAgg = agg.copy(child = Join(agg.child, rightOp, joinType, Option(replaced), hint)) // If there is no more filter to stay up, just return the Aggregate over Join. // Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)". - if (stayUp.isEmpty) newAgg else Filter(stayUp.reduce(And), newAgg) + if (stayUp.isEmpty) { + newAgg + } else { + joinType match { + // In case of Left semi join, the part of the join condition which does not refer + // to child attributes of the aggregate operator is kept as a Filter over the aggregate. + case LeftSemi => Filter(stayUp.reduce(And), newAgg) + // In case of left anti join, the join is pushed down when the entire join condition + // is eligible to be pushed down to preserve the semantics of left anti join. + case _ => join + } + } } else { // The join condition is not a subset of the Aggregate's GROUP BY columns, // no push down. 
@@ -114,7 +125,18 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { if (pushDown.nonEmpty && rightOpColumns.isEmpty) { val predicate = pushDown.reduce(And) val newPlan = w.copy(child = Join(w.child, rightOp, joinType, Option(predicate), hint)) - if (stayUp.isEmpty) newPlan else Filter(stayUp.reduce(And), newPlan) + if (stayUp.isEmpty) { + newPlan + } else { + joinType match { + // In case of Left semi join, the part of the join condition which does not refer + // to partition attributes of the window operator is kept as a Filter over the window. + case LeftSemi => Filter(stayUp.reduce(And), newPlan) + // In case of left anti join, the join is pushed down when the entire join condition + // is eligible to be pushed down to preserve the semantics of left anti join. + case _ => join + } + } } else { // The join condition is not a subset of the Window's PARTITION BY clause, // no push down. @@ -184,7 +206,14 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { if (pushDown.nonEmpty && rightOpColumns.isEmpty) { val newChild = insertJoin(Option(pushDown.reduceLeft(And))) if (stayUp.nonEmpty) { - Filter(stayUp.reduceLeft(And), newChild) + join.joinType match { + // In case of Left semi join, the part of the join condition which does not refer + // to attributes of the grandchild is kept as a Filter above the operator. + case LeftSemi => Filter(stayUp.reduce(And), newChild) + // In case of left anti join, the join is pushed down when the entire join condition + // is eligible to be pushed down to preserve the semantics of left anti join. 
+ case _ => join + } } else { newChild } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala index 1a0231e..185568d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala @@ -117,7 +117,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery.analyze) } - test("Aggregate: LeftSemiAnti join partial pushdown") { + test("Aggregate: LeftSemi join partial pushdown") { val originalQuery = testRelation .groupBy('b)('b, sum('c).as('sum)) .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 10)) @@ -132,6 +132,15 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Aggregate: LeftAnti join no pushdown") { + val originalQuery = testRelation + .groupBy('b)('b, sum('c).as('sum)) + .join(testRelation1, joinType = LeftAnti, condition = Some('b === 'd && 'sum === 10)) + + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, originalQuery.analyze) + } + test("LeftSemiAnti join over aggregate - no pushdown") { val originalQuery = testRelation .groupBy('b)('b, sum('c).as('sum)) @@ -174,7 +183,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Window: LeftSemiAnti partial pushdown") { + test("Window: LeftSemi partial pushdown") { // Attributes from join condition which does not refer to the window partition spec // are kept up in the plan as a Filter operator above Window. 
val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) @@ -195,6 +204,25 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Window: LeftAnti no pushdown") { + // Attributes from join condition which does not refer to the window partition spec + // are kept up in the plan as a Filter operator above Window. + val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) + + val originalQuery = testRelation + .select('a, 'b, 'c, winExpr.as('window)) + .join(testRelation1, joinType = LeftAnti, condition = Some('a === 'd && 'b > 5)) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b, 'c) + .window(winExpr.as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil) + .join(testRelation1, joinType = LeftAnti, condition = Some('a === 'd && 'b > 5)) + .select('a, 'b, 'c, 'window).analyze + comparePlans(optimized, correctAnswer) + } + test("Union: LeftSemiAnti join pushdown") { val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int) @@ -251,7 +279,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Unary: LeftSemiAnti join pushdown - partial pushdown") { + test("Unary: LeftSemi join pushdown - partial pushdown") { val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) val originalQuery = testRelationWithArrayType .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col")) @@ -267,6 +295,16 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Unary: LeftAnti join pushdown - no pushdown") { + val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) + val originalQuery = testRelationWithArrayType + .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col")) + .join(testRelation1, joinType = LeftAnti, condition = Some('b 
=== 'd && 'b === 'out_col)) + + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, originalQuery.analyze) + } + test("Unary: LeftSemiAnti join pushdown - no pushdown") { val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) val originalQuery = testRelationWithArrayType --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org