Github user gatorsmile commented on a diff in the pull request:

    https://github.com/apache/spark/pull/16954#discussion_r103853428
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala ---
    @@ -83,29 +95,150 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       }
     
       /**
    -   * Given a predicate expression and an input plan, it rewrites
    -   * any embedded existential sub-query into an existential join.
    -   * It returns the rewritten expression together with the updated plan.
    -   * Currently, it does not support null-aware joins. Embedded NOT IN predicates
    -   * are blocked in the Analyzer.
    +   * Given a predicate expression and an input plan, it rewrites any embedded existential sub-query
    +   * into an existential join. It returns the rewritten expression together with the updated plan.
    +   * Currently, it does not support NOT IN nested inside a NOT expression. This case is blocked in
    +   * the Analyzer.
        */
       private def rewriteExistentialExpr(
           exprs: Seq[Expression],
           plan: LogicalPlan): (Option[Expression], LogicalPlan) = {
         var newPlan = plan
         val newExprs = exprs.map { e =>
           e transformUp {
    -        case PredicateSubquery(sub, conditions, nullAware, _) =>
    -          // TODO: support null-aware join
    +        case Exists(sub, conditions, exprId) =>
               val exists = AttributeReference("exists", BooleanType, nullable = false)()
    -          newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
    +          newPlan = Join(newPlan, sub,
    +            ExistenceJoin(exists), conditions.reduceLeftOption(And))
               exists
    -        }
    +        case In(e, Seq(l@ ListQuery(sub, conditions, exprId))) =>
    +          val exists = AttributeReference("exists", BooleanType, nullable = false)()
    +          val inConditions = getValueExpression(e).zip(sub.output).map(EqualTo.tupled)
    +          newPlan = Join(newPlan, sub,
    +            ExistenceJoin(exists), (inConditions ++ conditions).reduceLeftOption(And))
    +          exists
    +      }
         }
         (newExprs.reduceOption(And), newPlan)
       }
     }
     
    + /**
    +  * Pull out all (outer) correlated predicates from a given subquery. This method removes the
    +  * correlated predicates from subquery [[Filter]]s and adds the references of these predicates
    +  * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to
    +  * be able to evaluate the predicates at the top level.
    +  *
    +  * TODO: Look to merge this rule with RewritePredicateSubquery.
    +  */
    +object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper {
    +   /**
    +    * Returns the correlated predicates and an updated plan that removes the outer references.
    +    */
    +  private def pullOutCorrelatedPredicates(
    +      sub: LogicalPlan,
    +      outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
    +    val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]]
    +
    +    /** Determine which correlated predicate references are missing from this plan. */
    +    def missingReferences(p: LogicalPlan): AttributeSet = {
    +      val localPredicateReferences = p.collect(predicateMap)
    +        .flatten
    +        .map(_.references)
    +        .reduceOption(_ ++ _)
    +        .getOrElse(AttributeSet.empty)
    +      localPredicateReferences -- p.outputSet
    +    }
    +
    +    // Simplify the predicates before pulling them out.
    +    val transformed = BooleanSimplification(sub) transformUp {
    +      case f @ Filter(cond, child) =>
    +        val (correlated, local) =
    +          splitConjunctivePredicates(cond).partition(SubExprUtils.containsOuter)
    +
    +        // Rewrite the filter without the correlated predicates if any.
    +        correlated match {
    +          case Nil => f
    +          case xs if local.nonEmpty =>
    +            val newFilter = Filter(local.reduce(And), child)
    +            predicateMap += newFilter -> xs
    +            newFilter
    +          case xs =>
    +            predicateMap += child -> xs
    +            child
    +        }
    +      case p @ Project(expressions, child) =>
    +        val referencesToAdd = missingReferences(p)
    +        if (referencesToAdd.nonEmpty) {
    +          Project(expressions ++ referencesToAdd, child)
    +        } else {
    +          p
    +        }
    +      case a @ Aggregate(grouping, expressions, child) =>
    +        val referencesToAdd = missingReferences(a)
    +        if (referencesToAdd.nonEmpty) {
    +          Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child)
    +        } else {
    +          a
    +        }
    +      case p =>
    +        p
    +    }
    +
    +    // Make sure the inner and the outer query attributes do not collide.
    +    // In case of a collision, change the subquery plan's output to use
    +    // different attributes by creating aliases.
    +    val baseConditions = predicateMap.values.flatten.toSeq
    +    val (newplan: LogicalPlan, newcond: Seq[Expression]) = if (outer.nonEmpty) {
    +      val outputSet = outer.map(_.outputSet).reduce(_ ++ _)
    +      val duplicates = transformed.outputSet.intersect(outputSet)
    +      val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) {
    +        val aliasMap = AttributeMap(duplicates.map { dup =>
    +          dup -> Alias(dup, dup.toString)()
    +        }.toSeq)
    +        val aliasedExpressions = transformed.output.map { ref =>
    +          aliasMap.getOrElse(ref, ref)
    +        }
    +        val aliasedProjection = Project(aliasedExpressions, transformed)
    +        val aliasedConditions = baseConditions.map(_.transform {
    +          case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute
    +        })
    +        (aliasedProjection, aliasedConditions)
    +      } else {
    +        (transformed, baseConditions)
    +      }
    +      (plan, SubExprUtils.stripOuterReferences(deDuplicatedConditions))
    +    } else {
    +      (transformed, SubExprUtils.stripOuterReferences(baseConditions))
    +    }
    +    (newplan, newcond)
    +  }
    +
    +  private def rewriteSubQueries(plan: LogicalPlan, outerPlans: Seq[LogicalPlan]): LogicalPlan = {
    +    plan transformExpressions {
    +      case s @ ScalarSubquery(sub, cond, exprId) if s.children.nonEmpty =>
    +        val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
    +        ScalarSubquery(newPlan, newCond, exprId)
    +      case e @ Exists(sub, cond, exprId) if e.children.nonEmpty =>
    +        val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
    +        Exists(newPlan, newCond, exprId)
    +      case l @ ListQuery(sub, cond, exprId) =>
    --- End diff ---
    
    `case ListQuery(sub, _, exprId) =>`
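    
    That is, presumably because the `l` binder and `cond` are not used in that branch, the unused pattern bindings can be replaced with wildcards so the match only names what it needs. A minimal, self-contained sketch of the idea, using a toy case class rather than Catalyst's actual `ListQuery`:
    
    ```scala
    // Toy stand-in for Catalyst's ListQuery, only to illustrate the suggestion:
    // bindings the case body never uses (here the `l @` binder and `cond`) can
    // be written as `_`.
    case class ListQuery(plan: String, conditions: Seq[String], exprId: Long)
    
    def describe(e: Any): String = e match {
      // Before: case l @ ListQuery(sub, cond, exprId) => ...   (l, cond unused)
      // After, as suggested:
      case ListQuery(sub, _, exprId) => s"ListQuery over $sub (exprId=$exprId)"
      case other                     => other.toString
    }
    ```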
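    
    For broader context on the hunk above: the new `rewriteExistentialExpr` replaces each `Exists(...)` or `In(value, ListQuery(...))` predicate with a boolean `exists` attribute and hangs the subquery off the outer plan as an `ExistenceJoin`. A rough, self-contained sketch of that shape on a toy plan model (not the real Catalyst classes; expression IDs and the equality/join conditions are ignored here):
    
    ```scala
    // Toy expression/plan ADTs standing in for Catalyst's.
    sealed trait Expr
    case class Attr(name: String) extends Expr
    case class ExistsSub(subquery: Plan) extends Expr
    case class InSub(value: Expr, subquery: Plan) extends Expr
    case class AndExpr(left: Expr, right: Expr) extends Expr
    
    sealed trait Plan
    case class Table(name: String) extends Plan
    case class Filter(condition: Expr, child: Plan) extends Plan
    case class ExistenceJoin(left: Plan, right: Plan, exists: Attr) extends Plan
    
    // Mirrors the idea of rewriteExistentialExpr: every subquery predicate in a
    // Filter condition becomes an ExistenceJoin under the Filter, and the
    // predicate itself is replaced by the join's boolean output attribute.
    def rewriteExistential(filter: Filter): Filter = {
      var newChild: Plan = filter.child
      var counter = 0
      def replace(e: Expr): Expr = e match {
        case AndExpr(l, r) => AndExpr(replace(l), replace(r))
        case ExistsSub(sub) =>
          counter += 1
          val flag = Attr(s"exists_$counter")
          newChild = ExistenceJoin(newChild, sub, flag)
          flag
        case InSub(_, sub) =>
          counter += 1
          val flag = Attr(s"exists_$counter")
          newChild = ExistenceJoin(newChild, sub, flag)
          flag
        case other => other
      }
      Filter(replace(filter.condition), newChild)
    }
    ```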

