Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21083#discussion_r183299130
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 ---
    @@ -664,53 +662,52 @@ object InferFiltersFromConstraints extends 
Rule[LogicalPlan] with PredicateHelpe
           }
     
         case join @ Join(left, right, joinType, conditionOpt) =>
    -      // Only consider constraints that can be pushed down completely to 
either the left or the
    -      // right child
    -      val constraints = join.allConstraints.filter { c =>
    -        c.references.subsetOf(left.outputSet) || 
c.references.subsetOf(right.outputSet)
    -      }
    -      // Remove those constraints that are already enforced by either the 
left or the right child
    -      val additionalConstraints = constraints -- (left.constraints ++ 
right.constraints)
    -      val newConditionOpt = conditionOpt match {
    -        case Some(condition) =>
    -          val newFilters = additionalConstraints -- 
splitConjunctivePredicates(condition)
    -          if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), 
condition)) else conditionOpt
    -        case None =>
    -          additionalConstraints.reduceOption(And)
    -      }
    -      // Infer filter for left/right outer joins
    -      val newLeftOpt = joinType match {
    -        case RightOuter if newConditionOpt.isDefined =>
    -          val inferredConstraints = left.getRelevantConstraints(
    -            left.constraints
    -              .union(right.constraints)
    -              
.union(splitConjunctivePredicates(newConditionOpt.get).toSet))
    -          val newFilters = inferredConstraints
    -            .filterNot(left.constraints.contains)
    -            .reduceLeftOption(And)
    -          newFilters.map(Filter(_, left))
    -        case _ => None
    -      }
    -      val newRightOpt = joinType match {
    -        case LeftOuter if newConditionOpt.isDefined =>
    -          val inferredConstraints = right.getRelevantConstraints(
    -            right.constraints
    -              .union(left.constraints)
    -              
.union(splitConjunctivePredicates(newConditionOpt.get).toSet))
    -          val newFilters = inferredConstraints
    -            .filterNot(right.constraints.contains)
    -            .reduceLeftOption(And)
    -          newFilters.map(Filter(_, right))
    -        case _ => None
    -      }
    +      joinType match {
    +        // For inner join, we can infer additional filters for both sides. 
LeftSemi is kind of an
    +        // inner join, it just drops the right side in the final output.
    +        case _: InnerLike | LeftSemi =>
    +          val allConstraints = getAllConstraints(left, right, conditionOpt)
    +          val newLeft = inferNewFilter(left, allConstraints)
    +          val newRight = inferNewFilter(right, allConstraints)
    +          join.copy(left = newLeft, right = newRight)
     
    -      if ((newConditionOpt.isDefined && (newConditionOpt ne conditionOpt))
    -        || newLeftOpt.isDefined || newRightOpt.isDefined) {
    -        Join(newLeftOpt.getOrElse(left), newRightOpt.getOrElse(right), 
joinType, newConditionOpt)
    -      } else {
    -        join
    +        // For right outer join, we can only infer additional filters for 
left side.
    +        case RightOuter =>
    +          val allConstraints = getAllConstraints(left, right, conditionOpt)
    +          val newLeft = inferNewFilter(left, allConstraints)
    +          join.copy(left = newLeft)
    +
    +        // For left join, we can only infer additional filters for right 
side.
    +        case LeftOuter | LeftAnti =>
    +          val allConstraints = getAllConstraints(left, right, conditionOpt)
    +          val newRight = inferNewFilter(right, allConstraints)
    +          join.copy(right = newRight)
    +
    +        case _ => join
           }
       }
    +
    +  private def getAllConstraints(
    +      left: LogicalPlan,
    +      right: LogicalPlan,
    +      conditionOpt: Option[Expression]): Set[Expression] = {
    +    val baseConstraints = left.constraints.union(right.constraints)
    +      
.union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet)
    +    baseConstraints.union(inferAdditionalConstraints(baseConstraints))
    +  }
    +
    +  private def inferNewFilter(plan: LogicalPlan, constraints: 
Set[Expression]): LogicalPlan = {
    +    val newPredicates = constraints
    +      .union(constructIsNotNullConstraints(constraints, plan.output))
    --- End diff --
    
    ```
    case _: InnerLike | LeftSemi =>
              val allConstraints = getAllConstraints(left, right, conditionOpt)
              val newLeft = inferNewFilter(left, allConstraints)
              val newRight = inferNewFilter(right, allConstraints)
              join.copy(left = newLeft, right = newRight)
    ```
    
    For this pattern, if we reuse the code you mentioned, we would need to 
expand the constraints twice: once for the left side and once for the right.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to