Github user aokolnychyi commented on a diff in the pull request: https://github.com/apache/spark/pull/18692#discussion_r137343500 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala --- @@ -152,3 +152,71 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) } } + +/** + * A rule that uses propagated constraints to infer join conditions. The optimization is applicable + * only to CROSS joins. + * + * For instance, if there is a CROSS join, where the left relation has 'a = 1' and the right + * relation has 'b = 1', then the rule infers 'a = b' as a join predicate. + */ +object InferJoinConditionsFromConstraints extends Rule[LogicalPlan] with PredicateHelper { + + def apply(plan: LogicalPlan): LogicalPlan = { + if (SQLConf.get.constraintPropagationEnabled) { + inferJoinConditions(plan) + } else { + plan + } + } + + private def inferJoinConditions(plan: LogicalPlan): LogicalPlan = plan transform { + case join @ Join(left, right, Cross, conditionOpt) => + val leftConstraints = join.constraints.filter(_.references.subsetOf(left.outputSet)) + val rightConstraints = join.constraints.filter(_.references.subsetOf(right.outputSet)) + val inferredJoinPredicates = inferJoinPredicates(leftConstraints, rightConstraints) + + val newConditionOpt = conditionOpt match { + case Some(condition) => + val existingPredicates = splitConjunctivePredicates(condition) + val newPredicates = findNewPredicates(inferredJoinPredicates, existingPredicates) + if (newPredicates.nonEmpty) Some(And(newPredicates.reduce(And), condition)) else None + case None => + inferredJoinPredicates.reduceOption(And) + } + if (newConditionOpt.isDefined) Join(left, right, Inner, newConditionOpt) else join --- End diff -- And what about CROSS joins with join conditions? Not sure if they will benefit from the proposed rule, but it is better to ask. 
``` Seq((1, 2)).toDF("col1", "col2").write.saveAsTable("t1") Seq((1, 2)).toDF("col1", "col2").write.saveAsTable("t2") val df = spark.sql("SELECT * FROM t1 CROSS JOIN t2 ON t1.col1 >= t2.col1 WHERE t1.col1 = 1 AND t2.col1 = 1") df.explain(true) == Optimized Logical Plan == Join Cross, (col1#40 >= col1#42) :- Filter (isnotnull(col1#40) && (col1#40 = 1)) : +- Relation[col1#40,col2#41] parquet +- Filter (isnotnull(col1#42) && (col1#42 = 1)) +- Relation[col1#42,col2#43] parquet ```
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org