This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new e1c90d6 [SPARK-19712][SQL] Pushdown LeftSemi/LeftAnti below join e1c90d6 is described below commit e1c90d66bbea5b4cb97226610701b0389b734651 Author: Dilip Biswal <dbis...@us.ibm.com> AuthorDate: Wed Apr 17 20:30:20 2019 +0800 [SPARK-19712][SQL] Pushdown LeftSemi/LeftAnti below join ## What changes were proposed in this pull request? This PR adds support for pushing down LeftSemi and LeftAnti joins below the Join operator. This is prerequisite work that's needed for the subsequent task of moving the subquery rewrites to the beginning of the optimization phase. The larger PR is [here](https://github.com/apache/spark/pull/23211). This PR addresses the comment at [link](https://github.com/apache/spark/pull/23211#issuecomment-445705922). ## How was this patch tested? Added tests under LeftSemiAntiJoinPushDownSuite. Closes #24331 from dilipbiswal/SPARK-19712-pushleftsemi-belowjoin. Authored-by: Dilip Biswal <dbis...@us.ibm.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 1 + .../optimizer/PushDownLeftSemiAntiJoin.scala | 104 ++++++++++++++++++ .../optimizer/LeftSemiAntiJoinPushDownSuite.scala | 117 ++++++++++++++++++++- 3 files changed, 221 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d0368be..afdf61e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -66,6 +66,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) PushPredicateThroughJoin, PushDownPredicate, PushDownLeftSemiAntiJoin, + PushLeftSemiLeftAntiThroughJoin, LimitPushDown, ColumnPruning, InferFiltersFromConstraints, diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala index d91f262..0c38900 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala @@ -159,3 +159,107 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { } } } + +/** + * This rule is a variant of [[PushPredicateThroughJoin]] which can handle + * pushing down Left semi and Left Anti joins below a join operator. The + * allowable join types are: + * 1) Inner + * 2) Cross + * 3) LeftOuter + * 4) RightOuter + * + * TODO: + * Currently this rule can push down the left semi or left anti joins to either + * left or right leg of the child join. This matches the behaviour of `PushPredicateThroughJoin` + * when the left semi or left anti join is in expression form. We need to explore the possibility + * to push the left semi/anti joins to both legs of join if the join condition refers to + * both left and right legs of the child join. + */ +object PushLeftSemiLeftAntiThroughJoin extends Rule[LogicalPlan] with PredicateHelper { + /** + * Define an enumeration to identify whether a LeftSemi/LeftAnti join can be pushed down to + * the left leg or the right leg of the join. + */ + object PushdownDirection extends Enumeration { + val TO_LEFT_BRANCH, TO_RIGHT_BRANCH, NONE = Value + } + + object AllowedJoin { + def unapply(join: Join): Option[Join] = join.joinType match { + case Inner | Cross | LeftOuter | RightOuter => Some(join) + case _ => None + } + } + + /** + * Determine which leg of the child join a LeftSemi/LeftAnti join can be pushed to.
+ */ + private def pushTo(leftChild: Join, rightChild: LogicalPlan, joinCond: Option[Expression]) = { + val left = leftChild.left + val right = leftChild.right + val joinType = leftChild.joinType + val rightOutput = rightChild.outputSet + + if (joinCond.nonEmpty) { + val conditions = splitConjunctivePredicates(joinCond.get) + val (leftConditions, rest) = + conditions.partition(_.references.subsetOf(left.outputSet ++ rightOutput)) + val (rightConditions, commonConditions) = + rest.partition(_.references.subsetOf(right.outputSet ++ rightOutput)) + + if (rest.isEmpty && leftConditions.nonEmpty) { + // When every join condition can be computed from the child join's left leg plus + // the leftsemi/anti join's right side, push the leftsemi/anti join to the left leg. + PushdownDirection.TO_LEFT_BRANCH + } else if (leftConditions.isEmpty && rightConditions.nonEmpty && commonConditions.isEmpty) { + // When every condition references the child join's right leg and can be computed + // from it plus the leftsemi/anti join's right side, push to the right leg. + PushdownDirection.TO_RIGHT_BRANCH + } else { + PushdownDirection.NONE + } + } else { + /** + * When the join condition is empty, + * 1) if this is a left outer join or inner join, push leftsemi/anti join down + * to the left leg of join.
+ * 2) if a right outer join, to the right leg of join. + */ + joinType match { + case _: InnerLike | LeftOuter => + PushdownDirection.TO_LEFT_BRANCH + case RightOuter => + PushdownDirection.TO_RIGHT_BRANCH + case _ => + PushdownDirection.NONE + } + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // push LeftSemi/LeftAnti down into the join below + case j @ Join(AllowedJoin(left), right, LeftSemiOrAnti(joinType), joinCond, parentHint) => + val (childJoinType, childLeft, childRight, childCondition, childHint) = + (left.joinType, left.left, left.right, left.condition, left.hint) + val action = pushTo(left, right, joinCond) + + action match { + case PushdownDirection.TO_LEFT_BRANCH + if (childJoinType == LeftOuter || childJoinType.isInstanceOf[InnerLike]) => + // push down leftsemi/anti join to the left leg (child is LeftOuter or inner-like) + val newLeft = Join(childLeft, right, joinType, joinCond, parentHint) + Join(newLeft, childRight, childJoinType, childCondition, childHint) + case PushdownDirection.TO_RIGHT_BRANCH + if (childJoinType == RightOuter || childJoinType.isInstanceOf[InnerLike]) => + // push down leftsemi/anti join to the right leg (child is RightOuter or inner-like) + val newRight = Join(childRight, right, joinType, joinCond, parentHint) + Join(childLeft, newRight, childJoinType, childCondition, childHint) + case _ => + // NONE, or the child join type rules out the chosen leg: keep the plan as is + j + } + } +} + + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala index 185568d..00709ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala @@ -37,13 +37,14 @@ class LeftSemiPushdownSuite extends PlanTest { CombineFilters, PushDownPredicate, PushDownLeftSemiAntiJoin, + PushLeftSemiLeftAntiThroughJoin,
BooleanSimplification, CollapseProject) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation1 = LocalRelation('d.int) + val testRelation2 = LocalRelation('e.int) test("Project: LeftSemiAnti join pushdown") { val originalQuery = testRelation @@ -314,4 +315,118 @@ class LeftSemiPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) } + + Seq(Some('d === 'e), None).foreach { case innerJoinCond => + Seq(LeftSemi, LeftAnti).foreach { case outerJT => + Seq(Inner, LeftOuter, Cross, RightOuter).foreach { case innerJT => + test(s"$outerJT pushdown empty join cond join type $innerJT join cond $innerJoinCond") { + val joinedRelation = testRelation1.join(testRelation2, joinType = innerJT, innerJoinCond) + val originalQuery = joinedRelation.join(testRelation, joinType = outerJT, None) + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = if (innerJT == RightOuter) { + val pushedDownJoin = testRelation2.join(testRelation, joinType = outerJT, None) + testRelation1.join(pushedDownJoin, joinType = innerJT, innerJoinCond).analyze + } else { + val pushedDownJoin = testRelation1.join(testRelation, joinType = outerJT, None) + pushedDownJoin.join(testRelation2, joinType = innerJT, innerJoinCond).analyze + } + comparePlans(optimized, correctAnswer) + } + } + } + } + + Seq(Some('d === 'e), None).foreach { case innerJoinCond => + Seq(LeftSemi, LeftAnti).foreach { case outerJT => + Seq(Inner, LeftOuter, Cross).foreach { case innerJT => + test(s"$outerJT pushdown to left of join type: $innerJT join condition $innerJoinCond") { + val joinedRelation = testRelation1.join(testRelation2, joinType = innerJT, innerJoinCond) + val originalQuery = + joinedRelation.join(testRelation, joinType = outerJT, condition = Some('a === 'd)) + val optimized = Optimize.execute(originalQuery.analyze) + + val pushedDownJoin = + testRelation1.join(testRelation, joinType = 
outerJT, condition = Some('a === 'd)) + val correctAnswer = + pushedDownJoin.join(testRelation2, joinType = innerJT, innerJoinCond).analyze + comparePlans(optimized, correctAnswer) + } + } + } + } + + Seq(Some('e === 'd), None).foreach { case innerJoinCond => + Seq(LeftSemi, LeftAnti).foreach { case outerJT => + Seq(Inner, RightOuter, Cross).foreach { case innerJT => + test(s"$outerJT pushdown to right of join type: $innerJT join condition $innerJoinCond") { + val joinedRelation = testRelation1.join(testRelation2, joinType = innerJT, innerJoinCond) + val originalQuery = + joinedRelation.join(testRelation, joinType = outerJT, condition = Some('a === 'e)) + val optimized = Optimize.execute(originalQuery.analyze) + + val pushedDownJoin = + testRelation2.join(testRelation, joinType = outerJT, condition = Some('a === 'e)) + val correctAnswer = + testRelation1.join(pushedDownJoin, joinType = innerJT, innerJoinCond).analyze + comparePlans(optimized, correctAnswer) + } + } + } + } + + Seq(LeftSemi, LeftAnti).foreach { case jt => + test(s"$jt no pushdown - join condition refers left leg - join type for RightOuter") { + val joinedRelation = testRelation1.join(testRelation2, joinType = RightOuter, None) + val originalQuery = + joinedRelation.join(testRelation, joinType = jt, condition = Some('a === 'd)) + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, originalQuery.analyze) + } + } + + Seq(LeftSemi, LeftAnti).foreach { case jt => + test(s"$jt no pushdown - join condition refers right leg - join type for LeftOuter") { + val joinedRelation = testRelation1.join(testRelation2, joinType = LeftOuter, None) + val originalQuery = + joinedRelation.join(testRelation, joinType = jt, condition = Some('a === 'e)) + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, originalQuery.analyze) + } + } + + Seq(LeftSemi, LeftAnti).foreach { case outerJT => + Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT => + 
test(s"$outerJT no pushdown - join condition refers both leg - join type $innerJT") { + val joinedRelation = testRelation1.join(testRelation2, joinType = innerJT, None) + val originalQuery = joinedRelation + .join(testRelation, joinType = outerJT, condition = Some('a === 'd && 'a === 'e)) + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, originalQuery.analyze) + } + } + } + + Seq(LeftSemi, LeftAnti).foreach { case outerJT => + Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT => + test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") { + val joinedRelation = testRelation1.join(testRelation2, joinType = innerJT, None) + val originalQuery = joinedRelation + .join(testRelation, joinType = outerJT, condition = Some('d + 'e === 'a)) + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, originalQuery.analyze) + } + } + } + + Seq(LeftSemi, LeftAnti).foreach { case jt => + test(s"$jt no pushdown when child join type is FullOuter") { + val joinedRelation = testRelation1.join(testRelation2, joinType = FullOuter, None) + val originalQuery = + joinedRelation.join(testRelation, joinType = jt, condition = Some('a === 'e)) + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, originalQuery.analyze) + } + } + } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org