Repository: spark
Updated Branches:
  refs/heads/master 999ec137a -> 6ac57fd0d


[SPARK-21417][SQL] Infer join conditions using propagated constraints

## What changes were proposed in this pull request?

This PR adds an optimization rule that infers join conditions using propagated 
constraints.

For instance, if the left relation of a join has the constraint 'a = 1' and the
right relation has 'b = 1', the rule infers 'a = b' as a join predicate.
The rule only rewrites CROSS joins that have no join condition: if it infers any
predicates, they become the join condition and the join is converted into an
INNER join.

Refer to the corresponding ticket and tests for more details.

## How was this patch tested?

This patch comes with a new test suite to cover the implemented logic.

Author: aokolnychyi <anton.okolnyc...@sap.com>

Closes #18692 from aokolnychyi/spark-21417.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6ac57fd0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6ac57fd0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6ac57fd0

Branch: refs/heads/master
Commit: 6ac57fd0d1c82b834eb4bf0dd57596b92a99d6de
Parents: 999ec13
Author: aokolnychyi <anton.okolnyc...@sap.com>
Authored: Thu Nov 30 14:25:10 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Thu Nov 30 14:25:10 2017 -0800

----------------------------------------------------------------------
 .../expressions/EquivalentExpressionMap.scala   |  66 +++++
 .../catalyst/expressions/ExpressionSet.scala    |   2 +
 .../sql/catalyst/optimizer/Optimizer.scala      |   1 +
 .../spark/sql/catalyst/optimizer/joins.scala    |  60 +++++
 .../EquivalentExpressionMapSuite.scala          |  56 +++++
 .../optimizer/EliminateCrossJoinSuite.scala     | 238 +++++++++++++++++++
 6 files changed, 423 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6ac57fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala
new file mode 100644
index 0000000..cf1614a
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMap.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.collection.mutable
+
+import 
org.apache.spark.sql.catalyst.expressions.EquivalentExpressionMap.SemanticallyEqualExpr
+
+/**
+ * A class that allows you to map an expression into a set of equivalent expressions. The keys are
+ * handled based on their semantic meaning and ignoring cosmetic differences. The values are
+ * represented as [[ExpressionSet]]s.
+ *
+ * The underlying representation of keys depends on the [[Expression.semanticHash]] and
+ * [[Expression.semanticEquals]] methods.
+ *
+ * {{{
+ *   val map = new EquivalentExpressionMap()
+ *
+ *   map.put(1 + 2, a)
+ *   map.put(rand(), b)
+ *
+ *   map.get(2 + 1) => Set(a) // 1 + 2 and 2 + 1 are semantically equivalent
+ *   map.get(1 + 2) => Set(a) // 1 + 2 and 2 + 1 are semantically equivalent
+ *   map.get(rand()) => Set() // non-deterministic expressions are not equivalent
+ * }}}
+ */
+class EquivalentExpressionMap {
+
+  private val equivalenceMap = mutable.HashMap.empty[SemanticallyEqualExpr, ExpressionSet]
+
+  /**
+   * Registers `equivalentExpression` as equivalent to `expression`. Adding the same pair more
+   * than once keeps a single entry, because values are collected in an [[ExpressionSet]].
+   *
+   * @param expression the key; it is looked up by semantic equality, not structural equality
+   * @param equivalentExpression the expression to associate with `expression`
+   */
+  def put(expression: Expression, equivalentExpression: Expression): Unit = {
+    // The entry is (re)written on the next line, so a plain lookup is enough. Inserting an empty
+    // placeholder first (as `getOrElseUpdate` would) only costs an extra hash-map write.
+    val equivalentExpressions = equivalenceMap.getOrElse(expression, ExpressionSet.empty)
+    equivalenceMap(expression) = equivalentExpressions + equivalentExpression
+  }
+
+  /**
+   * Returns the expressions registered as equivalent to `expression`, or an empty set when none
+   * were registered. Non-deterministic keys never match (see the example above).
+   */
+  def get(expression: Expression): Set[Expression] =
+    equivalenceMap.getOrElse(expression, ExpressionSet.empty)
+}
+
+object EquivalentExpressionMap {
+
+  /**
+   * Key wrapper whose hash code and equality are the semantic ones of the wrapped expression, so
+   * cosmetically different but equivalent expressions land on the same hash-map entry. Because it
+   * is implicit, callers can pass a plain [[Expression]] wherever a key is expected.
+   */
+  private implicit class SemanticallyEqualExpr(val expr: Expression) {
+    override def hashCode: Int = expr.semanticHash()
+
+    override def equals(obj: Any): Boolean = obj match {
+      case that: SemanticallyEqualExpr => expr.semanticEquals(that.expr)
+      case _ => false
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6ac57fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
index 7e8e7b8..e989083 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
@@ -27,6 +27,8 @@ object ExpressionSet {
     expressions.foreach(set.add)
     set
   }
+
+  val empty: ExpressionSet = ExpressionSet(Nil)
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/6ac57fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0d961bf..8a5c486 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -87,6 +87,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       PushProjectionThroughUnion,
       ReorderJoin,
       EliminateOuterJoin,
+      EliminateCrossJoin,
       InferFiltersFromConstraints,
       BooleanSimplification,
       PushPredicateThroughJoin,

http://git-wip-us.apache.org/repos/asf/spark/blob/6ac57fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index edbeaf2..29a3a7f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import scala.annotation.tailrec
+import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
@@ -152,3 +153,62 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with 
PredicateHelper {
       if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType 
= newJoinType))
   }
 }
+
+/**
+ * A rule that eliminates CROSS joins by inferring join conditions from propagated constraints.
+ *
+ * The optimization is applicable only to CROSS joins. For other join types, adding inferred join
+ * conditions could force the children to be shuffled: a child's existing partitioning might
+ * satisfy the join's requirements without the extra conditions but no longer satisfy them once
+ * the conditions are added.
+ *
+ * For instance, given a CROSS join with the constraint 'a = 1' from the left child and the
+ * constraint 'b = 1' from the right child, this rule infers a new join predicate 'a = b' and
+ * converts it to an Inner join.
+ *
+ * The rule does nothing when constraint propagation is disabled, and it only rewrites CROSS joins
+ * that do not have a join condition yet.
+ */
+object EliminateCrossJoin extends Rule[LogicalPlan] with PredicateHelper {
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (SQLConf.get.constraintPropagationEnabled) {
+      eliminateCrossJoin(plan)
+    } else {
+      plan
+    }
+  }
+
+  private def eliminateCrossJoin(plan: LogicalPlan): LogicalPlan = plan transform {
+    case join @ Join(leftPlan, rightPlan, Cross, None) =>
+      // Keep only the constraints that can be evaluated on one side of the join alone.
+      val leftConstraints = join.constraints.filter(_.references.subsetOf(leftPlan.outputSet))
+      val rightConstraints = join.constraints.filter(_.references.subsetOf(rightPlan.outputSet))
+      val inferredJoinPredicates = inferJoinPredicates(leftConstraints, rightConstraints)
+      // Convert to an INNER join only if at least one predicate was inferred.
+      inferredJoinPredicates.reduceOption(And).fold[LogicalPlan](join) { joinCondition =>
+        Join(leftPlan, rightPlan, Inner, Some(joinCondition))
+      }
+  }
+
+  /**
+   * Infers equality predicates that connect attributes of the two sides of a join.
+   *
+   * Every left-side constraint of the shape `attr = expr` (or `expr = attr`) records `attr` as
+   * equivalent to `expr`. Every right-side constraint of the same shape whose `expr` is
+   * semantically equal to a recorded one then produces `rightAttr = leftAttr` for each left
+   * attribute recorded under that expression.
+   *
+   * @param leftConstraints constraints referencing only the left child's output
+   * @param rightConstraints constraints referencing only the right child's output
+   * @return the inferred predicates, each of the form `rightAttr = leftAttr`
+   */
+  private def inferJoinPredicates(
+      leftConstraints: Set[Expression],
+      rightConstraints: Set[Expression]): mutable.Set[EqualTo] = {
+
+    val equivalentExpressionMap = new EquivalentExpressionMap()
+
+    // When both operands are attributes, only the first case fires, so the right operand is used
+    // as the key and the left operand as the value.
+    leftConstraints.foreach {
+      case EqualTo(attr: Attribute, expr: Expression) =>
+        equivalentExpressionMap.put(expr, attr)
+      case EqualTo(expr: Expression, attr: Attribute) =>
+        equivalentExpressionMap.put(expr, attr)
+      case _ =>
+    }
+
+    val joinConditions = mutable.Set.empty[EqualTo]
+
+    rightConstraints.foreach {
+      case EqualTo(attr: Attribute, expr: Expression) =>
+        joinConditions ++= equivalentExpressionMap.get(expr).map(EqualTo(attr, _))
+      case EqualTo(expr: Expression, attr: Attribute) =>
+        joinConditions ++= equivalentExpressionMap.get(expr).map(EqualTo(attr, _))
+      case _ =>
+    }
+
+    joinConditions
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6ac57fd0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala
new file mode 100644
index 0000000..bad7e17
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressionMapSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class EquivalentExpressionMapSuite extends SparkFunSuite {
+
+  private val onePlusTwo = Literal(1) + Literal(2)
+  private val twoPlusOne = Literal(2) + Literal(1)
+  private val rand = Rand(10)
+
+  test("behaviour of the equivalent expression map") {
+    val map = new EquivalentExpressionMap()
+    // Asserts that `key` maps to exactly the given expressions (compared as an ExpressionSet).
+    def assertEquivalents(key: Expression, expected: Expression*): Unit =
+      assertResult(ExpressionSet(expected))(map.get(key))
+
+    map.put(onePlusTwo, 'a)
+    map.put(Literal(1) + Literal(3), 'b)
+    map.put(rand, 'c)
+
+    // the entry stored under 1 + 2 is found when looking up 2 + 1
+    assertEquivalents(twoPlusOne, 'a)
+    // a non-deterministic key matches nothing, not even itself
+    assertEquivalents(rand)
+
+    // re-adding an existing (key, value) pair, under either spelling of the key, changes nothing
+    map.put(onePlusTwo, 'a)
+    map.put(twoPlusOne, 'a)
+    assertEquivalents(twoPlusOne, 'a)
+
+    // one key can collect several equivalent expressions
+    map.put(onePlusTwo, 'e)
+    assertEquivalents(onePlusTwo, 'a, 'e)
+    assertResult(2)(map.get(onePlusTwo).size)
+
+    // repeated non-deterministic keys still match nothing
+    map.put(rand, 'd)
+    assertEquivalents(rand)
+    assertResult(0)(map.get(rand).size)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6ac57fd0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala
new file mode 100644
index 0000000..e04dd28
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateCrossJoinSuite.scala
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal, 
Not, Rand}
+import org.apache.spark.sql.catalyst.plans.{Cross, Inner, JoinType, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED
+import org.apache.spark.sql.types.IntegerType
+
+/**
+ * Tests for [[EliminateCrossJoin]]. Each case filters a join of `testRelation1` ('a, 'b) and
+ * `testRelation2` ('c, 'd), runs the `Optimize` batch and compares the result with an expected plan.
+ */
+class EliminateCrossJoinSuite extends PlanTest {
+
+  // PushPredicateThroughJoin moves the filters into the children, so their constraints become
+  // visible to EliminateCrossJoin on a later iteration of the fixed point.
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Eliminate cross joins", FixedPoint(10),
+        EliminateCrossJoin,
+        PushPredicateThroughJoin) :: Nil
+  }
+
+  val testRelation1 = LocalRelation('a.int, 'b.int)
+  val testRelation2 = LocalRelation('c.int, 'd.int)
+
+  test("successful elimination of cross joins (1)") {
+    checkJoinOptimization(
+      originalFilter = 'a === 1 && 'c === 1 && 'd === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === 1,
+      expectedRightRelationFilter = 'c === 1 && 'd === 1,
+      expectedJoinType = Inner,
+      expectedJoinCondition = Some('a === 'c && 'a === 'd))
+  }
+
+  test("successful elimination of cross joins (2)") {
+    checkJoinOptimization(
+      originalFilter = 'a === 1 && 'b === 2 && 'd === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === 1 && 'b === 2,
+      expectedRightRelationFilter = 'd === 1,
+      expectedJoinType = Inner,
+      expectedJoinCondition = Some('a === 'd))
+  }
+
+  test("successful elimination of cross joins (3)") {
+    // PushPredicateThroughJoin will push 'd === 'a into the join condition
+    // EliminateCrossJoin will NOT apply because the condition will be already present
+    // therefore, the join type will stay the same (i.e., CROSS)
+    checkJoinOptimization(
+      originalFilter = 'a === 1 && Literal(1) === 'd && 'd === 'a,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === 1,
+      expectedRightRelationFilter = Literal(1) === 'd,
+      expectedJoinType = Cross,
+      expectedJoinCondition = Some('a === 'd))
+  }
+
+  test("successful elimination of cross joins (4)") {
+    // Literal(1) * Literal(2) and Literal(2) * Literal(1) are semantically equal
+    checkJoinOptimization(
+      originalFilter = 'a === Literal(1) * Literal(2) && Literal(2) * Literal(1) === 'c,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === Literal(1) * Literal(2),
+      expectedRightRelationFilter = Literal(2) * Literal(1) === 'c,
+      expectedJoinType = Inner,
+      expectedJoinCondition = Some('a === 'c))
+  }
+
+  test("successful elimination of cross joins (5)") {
+    checkJoinOptimization(
+      originalFilter = 'a === 1 && Literal(1) === 'a && 'c === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === 1 && Literal(1) === 'a,
+      expectedRightRelationFilter = 'c === 1,
+      expectedJoinType = Inner,
+      expectedJoinCondition = Some('a === 'c))
+  }
+
+  test("successful elimination of cross joins (6)") {
+    checkJoinOptimization(
+      originalFilter = 'a === Cast("1", IntegerType) && 'c === Cast("1", IntegerType) && 'd === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === Cast("1", IntegerType),
+      expectedRightRelationFilter = 'c === Cast("1", IntegerType) && 'd === 1,
+      expectedJoinType = Inner,
+      expectedJoinCondition = Some('a === 'c))
+  }
+
+  test("successful elimination of cross joins (7)") {
+    // The join condition appears due to PushPredicateThroughJoin. Since the join then already has
+    // a condition, EliminateCrossJoin does not apply and the join type stays CROSS.
+    checkJoinOptimization(
+      originalFilter = (('a >= 1 && 'c === 1) || 'd === 10) && 'b === 10 && 'c === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'b === 10,
+      expectedRightRelationFilter = 'c === 1,
+      expectedJoinType = Cross,
+      expectedJoinCondition = Some(('a >= 1 && 'c === 1) || 'd === 10))
+  }
+
+  test("successful elimination of cross joins (8)") {
+    checkJoinOptimization(
+      originalFilter = 'a === 1 && 'c === 1 && Literal(1) === 'a && Literal(1) === 'c,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a === 1 && Literal(1) === 'a,
+      expectedRightRelationFilter = 'c === 1 && Literal(1) === 'c,
+      expectedJoinType = Inner,
+      expectedJoinCondition = Some('a === 'c))
+  }
+
+  test("inability to detect join conditions when constraint propagation is disabled") {
+    withSQLConf(CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
+      checkJoinOptimization(
+        originalFilter = 'a === 1 && 'c === 1 && 'd === 1,
+        originalJoinType = Cross,
+        originalJoinCondition = None,
+        expectedFilter = None,
+        expectedLeftRelationFilter = 'a === 1,
+        expectedRightRelationFilter = 'c === 1 && 'd === 1,
+        expectedJoinType = Cross,
+        expectedJoinCondition = None)
+    }
+  }
+
+  test("inability to detect join conditions (1)") {
+    checkJoinOptimization(
+      originalFilter = 'a >= 1 && 'c === 1 && 'd >= 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = 'a >= 1,
+      expectedRightRelationFilter = 'c === 1 && 'd >= 1,
+      expectedJoinType = Cross,
+      expectedJoinCondition = None)
+  }
+
+  test("inability to detect join conditions (2)") {
+    checkJoinOptimization(
+      originalFilter = Literal(1) === 'b && ('c === 1 || 'd === 1),
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = Literal(1) === 'b,
+      expectedRightRelationFilter = 'c === 1 || 'd === 1,
+      expectedJoinType = Cross,
+      expectedJoinCondition = None)
+  }
+
+  test("inability to detect join conditions (3)") {
+    checkJoinOptimization(
+      originalFilter = Literal(1) === 'b && 'c === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = Some('c === 'b),
+      expectedFilter = None,
+      expectedLeftRelationFilter = Literal(1) === 'b,
+      expectedRightRelationFilter = 'c === 1,
+      expectedJoinType = Cross,
+      expectedJoinCondition = Some('c === 'b))
+  }
+
+  test("inability to detect join conditions (4)") {
+    checkJoinOptimization(
+      originalFilter = Not('a === 1) && 'd === 1,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = None,
+      expectedLeftRelationFilter = Not('a === 1),
+      expectedRightRelationFilter = 'd === 1,
+      expectedJoinType = Cross,
+      expectedJoinCondition = None)
+  }
+
+  test("inability to detect join conditions (5)") {
+    checkJoinOptimization(
+      originalFilter = 'a === Rand(10) && 'b === 1 && 'd === Rand(10) && 'c === 3,
+      originalJoinType = Cross,
+      originalJoinCondition = None,
+      expectedFilter = Some('a === Rand(10) && 'd === Rand(10)),
+      expectedLeftRelationFilter = 'b === 1,
+      expectedRightRelationFilter = 'c === 3,
+      expectedJoinType = Cross,
+      expectedJoinCondition = None)
+  }
+
+  /**
+   * Optimizes `testRelation1 JOIN testRelation2` (with the given join type and condition) under
+   * `originalFilter`. It then checks that the result equals the two relations filtered by the
+   * expected per-side filters, joined with the expected type and condition, and optionally topped
+   * by `expectedFilter`.
+   */
+  private def checkJoinOptimization(
+      originalFilter: Expression,
+      originalJoinType: JoinType,
+      originalJoinCondition: Option[Expression],
+      expectedFilter: Option[Expression],
+      expectedLeftRelationFilter: Expression,
+      expectedRightRelationFilter: Expression,
+      expectedJoinType: JoinType,
+      expectedJoinCondition: Option[Expression]): Unit = {
+
+    val originalQuery = testRelation1
+      .join(testRelation2, originalJoinType, originalJoinCondition)
+      .where(originalFilter)
+    val optimizedQuery = Optimize.execute(originalQuery.analyze)
+
+    val left = testRelation1.where(expectedLeftRelationFilter)
+    val right = testRelation2.where(expectedRightRelationFilter)
+    val join = left.join(right, expectedJoinType, expectedJoinCondition)
+    val expectedQuery = expectedFilter.fold(join)(join.where(_)).analyze
+
+    comparePlans(optimizedQuery, expectedQuery)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to