Github user aokolnychyi commented on a diff in the pull request: https://github.com/apache/spark/pull/22857#discussion_r229449496 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala --- @@ -736,3 +736,60 @@ object CombineConcats extends Rule[LogicalPlan] { flattenConcats(concat) } } + +/** + * A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations. + * + * This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates + * in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions. + * + * For example, `Filter(Literal(null, _))` is equal to `Filter(FalseLiteral)`. + * + * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`; + * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually + * `Filter(FalseLiteral)`. + * + * As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can + * benefit from it. For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` + * can be simplified into `Project(Literal(2))`. + * + * As a result, many unnecessary computations can be removed in the query optimization phase. + */ +object ReplaceNullWithFalse extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) + case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond))) + case p: LogicalPlan => p transformExpressions { + case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) + case cw @ CaseWhen(branches, _) => + val newBranches = branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> value + } + cw.copy(branches = newBranches) + } + } + + /** + * Recursively replaces `Literal(null, _)` with `FalseLiteral`. 
+ * + * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit + * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`. + */ + private def replaceNullWithFalse(e: Expression): Expression = e match { + case cw: CaseWhen if cw.dataType == BooleanType => + val newBranches = cw.branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> replaceNullWithFalse(value) + } + val newElseValue = cw.elseValue.map(replaceNullWithFalse) + CaseWhen(newBranches, newElseValue) + case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => + If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) --- End diff -- Let me know if I understood you correctly here.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org