Github user frreiss commented on a diff in the pull request:

    https://github.com/apache/spark/pull/13155#discussion_r66564793
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 ---
    @@ -1695,16 +1695,176 @@ object RewriteCorrelatedScalarSubquery extends 
Rule[LogicalPlan] {
       }
     
       /**
    +   * Statically evaluate an expression containing zero or more 
placeholders, given a set
    +   * of bindings for placeholder values.
    +   */
    +  private def evalExpr(expr : Expression, bindings : Map[Long, 
Option[Any]]) : Option[Any] = {
    +    val rewrittenExpr = expr transform {
    +      case r @ AttributeReference(_, dataType, _, _) =>
    +        bindings(r.exprId.id) match {
    +          case Some(v) => Literal.create(v, dataType)
    +          case None => Literal.default(NullType)
    +        }
    +    }
    +    Option(rewrittenExpr.eval())
    +  }
    +
    +  /**
    +   * Statically evaluate an expression containing one or more aggregates 
on an empty input.
    +   */
    +  private def evalAggOnZeroTups(expr : Expression) : Option[Any] = {
    +    // AggregateExpressions are Unevaluable, so we need to replace all 
aggregates
    +    // in the expression with the value they would return for zero input 
tuples.
    +    val rewrittenExpr = expr transform {
    +      case a @ AggregateExpression(aggFunc, _, _, resultId) =>
    +        aggFunc.defaultResult.getOrElse(Literal.default(NullType))
    +    }
    +    Option(rewrittenExpr.eval())
    +  }
    +
    +  /**
    +   * Statically evaluate a scalar subquery on an empty input.
    +   *
    +   * <b>WARNING:</b> This method only covers subqueries that pass the 
checks under
    +   * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the 
checks in
    +   * CheckAnalysis become less restrictive, this method will need to 
change.
    +   */
    +  private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
    +    // Inputs to this method will start with a chain of zero or more 
SubqueryAlias
    +    // and Project operators, followed by an optional Filter, followed by 
an
    +    // Aggregate. Traverse the operators recursively.
    +    def evalPlan(lp : LogicalPlan) : Map[Long, Option[Any]] = {
    +      lp match {
    +        case SubqueryAlias(_, child) => evalPlan(child)
    +        case Filter(condition, child) =>
    +          val bindings = evalPlan(child)
    +          if (bindings.size == 0) bindings
    +          else {
    +            val exprResult = evalExpr(condition, bindings).getOrElse(false)
    +              .asInstanceOf[Boolean]
    +            if (exprResult) bindings else Map()
    +          }
    +
    +        case Project(projectList, child) =>
    +          val bindings = evalPlan(child)
    +          if (bindings.size == 0) {
    +            bindings
    +          } else {
    +            projectList.map(ne => (ne.exprId.id, evalExpr(ne, 
bindings))).toMap
    +          }
    +
    +        case Aggregate(_, aggExprs, _) =>
    +          // Some of the expressions under the Aggregate node are the join 
columns
    +          // for joining with the outer query block. Fill those 
expressions in with
    +          // nulls and statically evaluate the remainder.
    +          aggExprs.map(ne => ne match {
    +            case AttributeReference(_, _, _, _) => (ne.exprId.id, None)
    +            case Alias(AttributeReference(_, _, _, _), _) => 
(ne.exprId.id, None)
    +            case _ => (ne.exprId.id, evalAggOnZeroTups(ne))
    +          }).toMap
    +
    +        case _ => sys.error(s"Unexpected operator in scalar subquery: $lp")
    +      }
    +    }
    +
    +    val resultMap = evalPlan(plan)
    +
    +    // By convention, the scalar subquery result is the leftmost field.
    +    resultMap(plan.output.head.exprId.id)
    +  }
    +
    +  /**
    +   * Split the plan for a scalar subquery into the parts above the 
Aggregate node
    +   * (first part of returned value) and the parts below the Aggregate 
node, including
    +   * the Aggregate (second part of returned value)
    +   */
    +  private def splitSubquery(plan : LogicalPlan) : Tuple2[Seq[LogicalPlan], 
Aggregate] = {
    +    var topPart = List[LogicalPlan]()
    +    var bottomPart : LogicalPlan = plan
    +    while (! bottomPart.isInstanceOf[Aggregate]) {
    +      topPart = bottomPart :: topPart
    +      bottomPart = bottomPart.children.head
    +    }
    +    (topPart, bottomPart.asInstanceOf[Aggregate])
    +  }
    +
    +  /**
    +   * Rewrite the nodes above the Aggregate in a subquery so that they 
generate an
    +   * auxiliary column "isFiltered"
    +   * @param subqueryPlan plan before rewrite
    +   * @param filteredId expression ID for the "isFiltered" column
    +   */
    +  private def addIsFiltered(subqueryPlan : LogicalPlan, filteredId : 
ExprId) : LogicalPlan = {
    +    val isFilteredRef = AttributeReference("isFiltered", 
BooleanType)(exprId = filteredId)
    +    val (topPart, aggNode) = splitSubquery(subqueryPlan)
    +    var rewrittenQuery: LogicalPlan = null
    +    if (topPart.size > 0 && topPart.head.isInstanceOf[Filter]) {
    +      // Correlated subquery has a HAVING clause
    +      // Rewrite the Filter into a Project that returns the value of the 
filtering predicate
    +      val origFilter = topPart.head.asInstanceOf[Filter]
    +      var topRemainder = topPart.tail
    +      val newProjectList =
    +        origFilter.output :+ Alias(origFilter.condition, 
"isFiltered")(exprId = filteredId)
    +      val filterAsProject = Project(newProjectList, origFilter.child)
    +
    +      rewrittenQuery = filterAsProject
    +      while (topRemainder.size > 0) {
    +        rewrittenQuery = topRemainder.head match {
    +          case Project(origList, _) => Project(origList :+ isFilteredRef, 
rewrittenQuery)
    +          case SubqueryAlias(alias, _) => SubqueryAlias(alias, 
rewrittenQuery)
    +        }
    +        topRemainder = topRemainder.tail
    +      }
    +    } else {
    +      // Correlated subquery without HAVING clause
    +      // Add an additional Project that adds a constant value for 
"isFiltered"
    +      rewrittenQuery = Project(subqueryPlan.output :+ 
Alias(Literal(false), "isFiltered")
    +      (exprId = filteredId), subqueryPlan)
    +    }
    +    return rewrittenQuery
    +  }
    +
    +  /**
        * Construct a new child plan by left joining the given subqueries to a 
base plan.
        */
       private def constructLeftJoins(
           child: LogicalPlan,
           subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
         subqueries.foldLeft(child) {
           case (currentChild, ScalarSubquery(query, conditions, _)) =>
    -        Project(
    -          currentChild.output :+ query.output.head,
    -          Join(currentChild, query, LeftOuter, 
conditions.reduceOption(And)))
    +        val origOutput = query.output.head
    +
    +        val resultWithZeroTups = evalSubqueryOnZeroTups(query)
    +        if (resultWithZeroTups.isEmpty) {
    +          Project(
    +            currentChild.output :+ origOutput,
    +            Join(currentChild, query, LeftOuter, 
conditions.reduceOption(And)))
    +        } else {
    --- End diff --
    
    @hvanhovell The conditions for the COUNT bug in this rule are actually a 
bit less strict. If the correlated subquery returns an answer other than NULL 
when no tuples join with the correlation bindings, then the COUNT bug may 
occur. Take this query, for example:
    ```sql
    select l.a from l where
            (select case when count(*) = 1 then null else count(*) end as cnt
            from r where l.a = r.c) = 0
    ```
    This subquery returns NULL when exactly 1 tuple from `r` joins with a set 
of correlation bindings and a non-NULL value otherwise. Rewrite this query 
without compensating for the COUNT bug and you get:
    ```sql
    select l.a from l left join 
      (select r.c, case when count(*) = 1 then null else count(*) end as cnt 
from r group by r.c) SQ
    on l.a = SQ.c
    where SQ.cnt = 0
    ```
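    If the value of `l.a` matches zero tuples in `r`, the outer join produces a 
tuple with `cnt` set to NULL, so the predicate `SQ.cnt = 0` evaluates to NULL 
and the rewritten query incorrectly drops that row. In the original query, the 
subquery evaluates to 0 for that value of `l.a`, so the row should be returned.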
    
    Technically, the full condition for the COUNT bug has two parts: the 
subquery returns a non-NULL value when zero tuples join, *and* the outer query 
block returns a different answer when it receives NULL in place of that value. 
Checking the outer query block would be a somewhat useful performance 
optimization, but it would require additional static evaluation code.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to