Repository: spark Updated Branches: refs/heads/master 999ec137a -> 6ac57fd0d
[SPARK-21417][SQL] Infer join conditions using propagated constraints ## What changes were proposed in this pull request? This PR adds an optimization rule that infers join conditions using propagated constraints. For instance, if there is a CROSS join without a join condition where the left relation has 'a = 1' and the right relation has 'b = 1', then the rule infers 'a = b' as a join predicate and converts the join into an INNER join. Joins of other types, and CROSS joins that already have a condition, are left unchanged. Refer to the corresponding ticket and tests for more details. ## How was this patch tested? This patch comes with new test suites to cover the implemented logic. Author: aokolnychyi <anton.okolnyc...@sap.com> Closes #18692 from aokolnychyi/spark-21417. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6ac57fd0 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6ac57fd0 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6ac57fd0 Branch: refs/heads/master Commit: 6ac57fd0d1c82b834eb4bf0dd57596b92a99d6de Parents: 999ec13 Author: aokolnychyi <anton.okolnyc...@sap.com> Authored: Thu Nov 30 14:25:10 2017 -0800 Committer: gatorsmile <gatorsm...@gmail.com> Committed: Thu Nov 30 14:25:10 2017 -0800 ---------------------------------------------------------------------- .../expressions/EquivalentExpressionMap.scala | 66 +++++ .../catalyst/expressions/ExpressionSet.scala | 2 + .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../spark/sql/catalyst/optimizer/joins.scala | 60 +++++ .../EquivalentExpressionMapSuite.scala | 56 +++++ .../optimizer/EliminateCrossJoinSuite.scala | 238 +++++++++++++++++++ 6 files changed, 423 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/6ac57fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala ---------------------------------------------------------------------- diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala
new file mode 100644
index 0000000..cf1614a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.EquivalentExpressionMap.SemanticallyEqualExpr
+
+/**
+ * A class that allows you to map an expression into a set of equivalent expressions. The keys are
+ * handled based on their semantic meaning and ignoring cosmetic differences. The values are
+ * represented as [[ExpressionSet]]s.
+ *
+ * The underlying representation of keys depends on the [[Expression.semanticHash]] and
+ * [[Expression.semanticEquals]] methods.
+ *
+ * {{{
+ *   val map = new EquivalentExpressionMap()
+ *
+ *   map.put(1 + 2, a)
+ *   map.put(rand(), b)
+ *
+ *   map.get(2 + 1) => Set(a) // 1 + 2 and 2 + 1 are semantically equivalent
+ *   map.get(1 + 2) => Set(a) // the key itself is found as well
+ *   map.get(rand()) => Set() // non-deterministic expressions are never equivalent
+ * }}}
+ */
+class EquivalentExpressionMap {
+
+  private val equivalenceMap = mutable.HashMap.empty[SemanticallyEqualExpr, ExpressionSet]
+
+  /**
+   * Records `equivalentExpression` as equivalent to `expression`.
+   *
+   * Non-deterministic keys are skipped: they are not semantically equal to any expression, not
+   * even to themselves, so an entry stored under such a key could never be retrieved by `get`.
+   */
+  def put(expression: Expression, equivalentExpression: Expression): Unit = {
+    if (expression.deterministic) {
+      val key: SemanticallyEqualExpr = expression
+      val knownEquivalents = equivalenceMap.getOrElse(key, ExpressionSet.empty)
+      equivalenceMap(key) = knownEquivalents + equivalentExpression
+    }
+  }
+
+  /**
+   * Returns the expressions recorded as equivalent to `expression`, or an empty set if there are
+   * none (which is always the case for a non-deterministic `expression`).
+   */
+  def get(expression: Expression): Set[Expression] =
+    equivalenceMap.getOrElse(expression, ExpressionSet.empty)
+}
+
+object EquivalentExpressionMap {
+
+  /**
+   * Wraps an [[Expression]] so that `equals` and `hashCode` delegate to
+   * [[Expression.semanticEquals]] and [[Expression.semanticHash]], which lets it act as a map key.
+   */
+  private implicit class SemanticallyEqualExpr(val expr: Expression) {
+    override def equals(obj: Any): Boolean = obj match {
+      case other: SemanticallyEqualExpr => expr.semanticEquals(other.expr)
+      case _ => false
+    }
+
+    override def hashCode: Int = expr.semanticHash()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6ac57fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
index 7e8e7b8..e989083 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
@@ -27,6 +27,9 @@ object ExpressionSet {
     expressions.foreach(set.add)
     set
   }
+
+  /** An empty [[ExpressionSet]]. */
+  val empty: ExpressionSet = ExpressionSet(Nil)
 }
 
 /**
http://git-wip-us.apache.org/repos/asf/spark/blob/6ac57fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0d961bf..8a5c486 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -87,6 +87,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
         PushProjectionThroughUnion,
         ReorderJoin,
         EliminateOuterJoin,
+        EliminateCrossJoin,
         InferFiltersFromConstraints,
         BooleanSimplification,
         PushPredicateThroughJoin,

http://git-wip-us.apache.org/repos/asf/spark/blob/6ac57fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index edbeaf2..29a3a7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -152,3 +152,80 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
       if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType))
   }
 }
+
+/**
+ * A rule that eliminates CROSS joins by inferring join conditions from propagated constraints.
+ *
+ * The optimization is applicable only to CROSS joins that have no join condition. For other join
+ * types, adding inferred join conditions could force the children to be shuffled: a child's
+ * existing partitioning might satisfy the JOIN node's distribution requirements without the
+ * extra predicates but no longer satisfy them once the predicates are added.
+ *
+ * For instance, given a CROSS join with the constraint 'a = 1' from the left child and the
+ * constraint 'b = 1' from the right child, this rule infers a new join predicate 'a = b' and
+ * converts it to an Inner join.
+ *
+ * The rule leaves the plan untouched when constraint propagation is disabled.
+ */
+object EliminateCrossJoin extends Rule[LogicalPlan] with PredicateHelper {
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (SQLConf.get.constraintPropagationEnabled) {
+      eliminateCrossJoin(plan)
+    } else {
+      plan
+    }
+  }
+
+  /**
+   * Rewrites every condition-less CROSS join into an INNER join whose condition is the conjunction
+   * of the inferred predicates. Joins for which nothing can be inferred are left unchanged.
+   */
+  private def eliminateCrossJoin(plan: LogicalPlan): LogicalPlan = plan transform {
+    case join @ Join(leftPlan, rightPlan, Cross, None) =>
+      val leftConstraints = join.constraints.filter(_.references.subsetOf(leftPlan.outputSet))
+      val rightConstraints = join.constraints.filter(_.references.subsetOf(rightPlan.outputSet))
+      inferJoinPredicates(leftConstraints, rightConstraints)
+        .reduceOption(And)
+        .map(condition => Join(leftPlan, rightPlan, Inner, Some(condition)))
+        .getOrElse(join)
+  }
+
+  /**
+   * Infers equality predicates that connect attributes of the two join sides.
+   *
+   * Each left constraint of the form `attr = expr` (or `expr = attr`) records `attr` as equivalent
+   * to `expr`. Each right constraint of the same form whose `expr` is semantically equal to a
+   * recorded expression then yields `rightAttr = leftAttr` for every matching left attribute.
+   */
+  private def inferJoinPredicates(
+      leftConstraints: Set[Expression],
+      rightConstraints: Set[Expression]): Set[EqualTo] = {
+
+    val equivalentExpressionMap = new EquivalentExpressionMap()
+
+    leftConstraints.foreach { constraint =>
+      attributeEquality(constraint).foreach { case (attr, expr) =>
+        equivalentExpressionMap.put(expr, attr)
+      }
+    }
+
+    rightConstraints.flatMap { constraint =>
+      attributeEquality(constraint).toSeq.flatMap { case (attr, expr) =>
+        equivalentExpressionMap.get(expr).map(EqualTo(attr, _))
+      }
+    }
+  }
+
+  /**
+   * Splits a constraint of the form `attr = expr` or `expr = attr` into `(attr, expr)`. When both
+   * operands are attributes, the left operand is taken as `attr`. Returns None for any other
+   * constraint.
+   */
+  private def attributeEquality(constraint: Expression): Option[(Attribute, Expression)] =
+    constraint match {
+      case EqualTo(attr: Attribute, expr) => Some((attr, expr))
+      case EqualTo(expr, attr: Attribute) => Some((attr, expr))
+      case _ => None
+    }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6ac57fd0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala
new file mode 100644
index 0000000..bad7e17
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class EquivalentExpressionMapSuite extends SparkFunSuite {
+
+  private val onePlusTwo = Literal(1) + Literal(2)
+  private val twoPlusOne = Literal(2) + Literal(1)
+  private val rand = Rand(10)
+
+  test("semantically equal keys are treated as the same key") {
+    val map = new EquivalentExpressionMap()
+    map.put(onePlusTwo, 'a)
+    map.put(Literal(1) + Literal(3), 'b)
+
+    // 1 + 2 should be equivalent to 2 + 1, and 1 + 3 to 3 + 1
+    assertResult(ExpressionSet(Seq('a)))(map.get(twoPlusOne))
+    assertResult(ExpressionSet(Seq('b)))(map.get(Literal(3) + Literal(1)))
+  }
+
+  test("keys that were never added map to an empty set") {
+    val map = new EquivalentExpressionMap()
+    map.put(onePlusTwo, 'a)
+    assertResult(ExpressionSet.empty)(map.get(Literal(1) + Literal(4)))
+  }
+
+  test("adding the same (key, value) pair several times yields a single entry") {
+    val map = new EquivalentExpressionMap()
+    map.put(onePlusTwo, 'a)
+    map.put(onePlusTwo, 'a)
+    map.put(twoPlusOne, 'a)
+    assertResult(ExpressionSet(Seq('a)))(map.get(twoPlusOne))
+  }
+
+  test("a key can map to several equivalent expressions") {
+    val map = new EquivalentExpressionMap()
+    map.put(onePlusTwo, 'a)
+    map.put(onePlusTwo, 'e)
+    assertResult(ExpressionSet(Seq('a, 'e)))(map.get(onePlusTwo))
+    assertResult(2)(map.get(onePlusTwo).size)
+  }
+
+  test("non-deterministic expressions are never equivalent, not even to themselves") {
+    val map = new EquivalentExpressionMap()
+    map.put(rand, 'c)
+    map.put(rand, 'd)
+    assertResult(ExpressionSet.empty)(map.get(rand))
+    assertResult(0)(map.get(rand).size)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6ac57fd0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala
new file mode 100644
index 0000000..e04dd28
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal, Not, Rand}
+import org.apache.spark.sql.catalyst.plans.{Cross, Inner, JoinType, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED
+import org.apache.spark.sql.types.IntegerType
+
+class EliminateCrossJoinSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Eliminate cross joins", FixedPoint(10),
+        EliminateCrossJoin,
+        PushPredicateThroughJoin) :: Nil
+  }
+
+  val testRelation1 = LocalRelation('a.int, 'b.int)
+  val testRelation2 = LocalRelation('c.int, 'd.int)
+
+  test("successful elimination of cross joins (1)") {
+    checkJoinOptimization(
+      originalFilter = 'a === 1 && 'c === 1 && 'd === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === 1,
+      expectedRightRelationFilter = 'c === 1 && 'd === 1,
+      expectedJoinType = Inner,
+      expectedJoinCondition = Some('a === 'c && 'a === 'd))
+  }
+
+  test("successful elimination of cross joins (2)") {
+    checkJoinOptimization(
+      originalFilter = 'a === 1 && 'b === 2 && 'd === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === 1 && 'b === 2,
+      expectedRightRelationFilter = 'd === 1,
+      expectedJoinType = Inner,
+      expectedJoinCondition = Some('a === 'd))
+  }
+
+  test("successful elimination of cross joins (3)") {
+    // PushPredicateThroughJoin will push 'd === 'a into the join condition
+    // EliminateCrossJoin will NOT apply because the condition will be already present
+    // therefore, the join type will stay the same (i.e., CROSS)
+    checkJoinOptimization(
+      originalFilter = 'a === 1 && Literal(1) === 'd && 'd === 'a,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === 1,
+      expectedRightRelationFilter = Literal(1) === 'd,
+      expectedJoinType = Cross,
+      expectedJoinCondition = Some('a === 'd))
+  }
+
+  test("successful elimination of cross joins (4)") {
+    // Literal(1) * Literal(2) and Literal(2) * Literal(1) are semantically equal
+    checkJoinOptimization(
+      originalFilter = 'a === Literal(1) * Literal(2) && Literal(2) * Literal(1) === 'c,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === Literal(1) * Literal(2),
+      expectedRightRelationFilter = Literal(2) * Literal(1) === 'c,
+      expectedJoinType = Inner,
+      expectedJoinCondition = Some('a === 'c))
+  }
+
+  test("successful elimination of cross joins (5)") {
+    checkJoinOptimization(
+      originalFilter = 'a === 1 && Literal(1) === 'a && 'c === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === 1 && Literal(1) === 'a,
+      expectedRightRelationFilter = 'c === 1,
+      expectedJoinType = Inner,
+      expectedJoinCondition = Some('a === 'c))
+  }
+
+  test("successful elimination of cross joins (6)") {
+    checkJoinOptimization(
+      originalFilter = 'a === Cast("1", IntegerType) && 'c === Cast("1", IntegerType) && 'd === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === Cast("1", IntegerType),
+      expectedRightRelationFilter = 'c === Cast("1", IntegerType) && 'd === 1,
+      expectedJoinType = Inner,
+      expectedJoinCondition = Some('a === 'c))
+  }
+
+  test("successful elimination of cross joins (7)") {
+    // The join condition appears due to PushPredicateThroughJoin
+    checkJoinOptimization(
+      originalFilter = (('a >= 1 && 'c === 1) || 'd === 10) && 'b === 10 && 'c === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'b === 10,
+      expectedRightRelationFilter = 'c === 1,
+      expectedJoinType = Cross,
+      expectedJoinCondition = Some(('a >= 1 && 'c === 1) || 'd === 10))
+  }
+
+  test("successful elimination of cross joins (8)") {
+    checkJoinOptimization(
+      originalFilter = 'a === 1 && 'c === 1 && Literal(1) === 'a && Literal(1) === 'c,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === 1 && Literal(1) === 'a,
+      expectedRightRelationFilter = 'c === 1 && Literal(1) === 'c,
+      expectedJoinType = Inner,
+      expectedJoinCondition = Some('a === 'c))
+  }
+
+  test("inability to detect join conditions when constraint propagation is disabled") {
+    withSQLConf(CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
+      checkJoinOptimization(
+        originalFilter = 'a === 1 && 'c === 1 && 'd === 1,
+        originalJoinType = Cross,
+        originalJoinCondition = None,
+        expectedFilter = None,
+        expectedLeftRelationFilter = 'a === 1,
+        expectedRightRelationFilter = 'c === 1 && 'd === 1,
+        expectedJoinType = Cross,
+        expectedJoinCondition = None)
+    }
+  }
+
+  test("inability to detect join conditions (1)") {
+    checkJoinOptimization(
+      originalFilter = 'a >= 1 && 'c === 1 && 'd >= 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a >= 1,
+      expectedRightRelationFilter = 'c === 1 && 'd >= 1,
+      expectedJoinType = Cross,
+      expectedJoinCondition = None)
+  }
+
+  test("inability to detect join conditions (2)") {
+    checkJoinOptimization(
+      originalFilter = Literal(1) === 'b && ('c === 1 || 'd === 1),
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = Literal(1) === 'b,
+      expectedRightRelationFilter = 'c === 1 || 'd === 1,
+      expectedJoinType = Cross,
+      expectedJoinCondition = None)
+  }
+
+  test("inability to detect join conditions (3)") {
+    checkJoinOptimization(
+      originalFilter = Literal(1) === 'b && 'c === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = Some('c === 'b),
+      expectedFilter = None,
+      expectedLeftRelationFilter = Literal(1) === 'b,
+      expectedRightRelationFilter = 'c === 1,
+      expectedJoinType = Cross,
+      expectedJoinCondition = Some('c === 'b))
+  }
+
+  test("inability to detect join conditions (4)") {
+    checkJoinOptimization(
+      originalFilter = Not('a === 1) && 'd === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = Not('a === 1),
+      expectedRightRelationFilter = 'd === 1,
+      expectedJoinType = Cross,
+      expectedJoinCondition = None)
+  }
+
+  test("inability to detect join conditions (5)") {
+    checkJoinOptimization(
+      originalFilter = 'a === Rand(10) && 'b === 1 && 'd === Rand(10) && 'c === 3,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = Some('a === Rand(10) && 'd === Rand(10)),
+      expectedLeftRelationFilter = 'b === 1,
+      expectedRightRelationFilter = 'c === 3,
+      expectedJoinType = Cross,
+      expectedJoinCondition = None)
+  }
+
+  private def checkJoinOptimization(
+      originalFilter: Expression,
+      originalJoinType: JoinType,
+      originalJoinCondition: Option[Expression],
+      expectedFilter: Option[Expression],
+      expectedLeftRelationFilter: Expression,
+      expectedRightRelationFilter: Expression,
+      expectedJoinType: JoinType,
+      expectedJoinCondition: Option[Expression]): Unit = {
+
+    val originalQuery = testRelation1
+      .join(testRelation2, originalJoinType, originalJoinCondition)
+      .where(originalFilter)
+    val optimizedQuery = Optimize.execute(originalQuery.analyze)
+
+    val left = testRelation1.where(expectedLeftRelationFilter)
+    val right = testRelation2.where(expectedRightRelationFilter)
+    val join = left.join(right, expectedJoinType, expectedJoinCondition)
+    val expectedQuery = expectedFilter.fold(join)(join.where(_)).analyze
+
+    comparePlans(optimizedQuery, expectedQuery)
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org