This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 62956c92cfc7 [SPARK-46707][SQL] Added throwable field to expressions to improve predicate pushdown
62956c92cfc7 is described below

commit 62956c92cfc74d7523328d168b6d837938cde763
Author: Kelvin Jiang <kelvin.ji...@databricks.com>
AuthorDate: Thu Jan 18 19:25:24 2024 +0800

    [SPARK-46707][SQL] Added throwable field to expressions to improve predicate pushdown

    ### What changes were proposed in this pull request?

    This PR adds the field `throwable` to `Expression`. If an expression is marked as throwable, we avoid pushing filters containing it through joins, filters, and aggregations (i.e. operators that filter their input).

    ### Why are the changes needed?

    With predicate pushdown, it is currently possible that a pushed-down filter ends up being evaluated on more rows than before it was pushed down (e.g. if the filter is pushed through a selective join). In that case, the filter may now be evaluated on a row that causes a runtime error to be thrown, which could not have happened before the pushdown.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Added UTs.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #44716 from kelvinjian-db/SPARK-46707-throwable.

    Authored-by: Kelvin Jiang <kelvin.ji...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/expressions/Expression.scala      |  5 ++
 .../expressions/collectionOperations.scala         |  3 ++
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 27 +++++-----
 .../catalyst/optimizer/FilterPushdownSuite.scala   | 63 ++++++++++++++++++++++
 4 files changed, 84 insertions(+), 14 deletions(-)
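To make the hazard described under "Why are the changes needed?" concrete, here is a hedged spark-shell sketch (table names and data are illustrative and not part of the patch; a spark-shell session is assumed, so `spark` and its implicits are in scope):

    // dims is selective: the join removes the facts row with a > b.
    Seq((5L, 1L, 0L), (1L, 5L, 1L)).toDF("a", "b", "k").createOrReplaceTempView("facts")
    Seq(1L).toDF("k").createOrReplaceTempView("dims")

    spark.sql(
      """SELECT * FROM facts f JOIN dims d ON f.k = d.k
        |WHERE isnotnull(sequence(f.a, f.b, CAST(1 AS BIGINT)))""".stripMargin).collect()
    // Evaluated above the join, sequence() only ever sees (a = 1, b = 5): fine.
    // Pushed below the join, it would also see (a = 5, b = 1) with step 1 and
    // abort the whole query with "Illegal sequence boundaries: 5 to 1 by 1",
    // a failure the unpushed plan could never produce.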
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 2cc813bd3055..484418f5e5a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -140,6 +140,11 @@ abstract class Expression extends TreeNode[Expression] {
    */
   def stateful: Boolean = false
 
+  /**
+   * Returns true if the expression could potentially throw an exception when evaluated.
+   */
+  lazy val throwable: Boolean = children.exists(_.throwable)
+
   /**
    * Returns a copy of this expression where all stateful expressions are replaced with fresh
    * uninitialized copies. If the expression contains no stateful expressions then the original

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 04f56eaf8c1e..5aa96dd1a6aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2983,6 +2983,9 @@ case class Sequence(
 
   override def nullable: Boolean = children.exists(_.nullable)
 
+  // If step is defined, then an error will be thrown if the start and stop do not satisfy the step.
+  override lazy val throwable: Boolean = stepOpt.isDefined
+
   override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false)
 
   override def checkInputDataTypes(): TypeCheckResult = {
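A quick illustration of why an explicit step makes `Sequence` throwable (a spark-shell sketch; the exact exception class and message vary across Spark versions):

    // Inconsistent boundaries only fail when a step is supplied:
    spark.sql("SELECT sequence(5, 1, 1)").collect()
    // => java.lang.IllegalArgumentException: Illegal sequence boundaries: 5 to 1 by 1

    // With no step, one is inferred from start and stop, so evaluation cannot
    // fail this way and the predicate stays eligible for pushdown:
    spark.sql("SELECT sequence(5, 1)").collect()
    // => [[5, 4, 3, 2, 1]]

Because the `Expression` default is `children.exists(_.throwable)`, any predicate built on top of such a `Sequence`, e.g. `IsNotNull(Sequence(...))`, is throwable as well, which is what the optimizer rules below key on.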
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 8fcc7c7c26b4..4186c8c1db91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1549,10 +1549,11 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
 
   val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
     // The query execution/optimization does not guarantee the expressions are evaluated in order.
-    // We only can combine them if and only if both are deterministic.
+    // We can combine them only if both are deterministic and the outer condition is not
+    // throwable (the inner one may be throwable, since it would have been evaluated first anyway).
     case Filter(fc, nf @ Filter(nc, grandChild)) if nc.deterministic =>
-      val (combineCandidates, nonDeterministic) =
-        splitConjunctivePredicates(fc).partition(_.deterministic)
+      val (combineCandidates, rest) =
+        splitConjunctivePredicates(fc).partition(p => p.deterministic && !p.throwable)
       val mergedFilter = (ExpressionSet(combineCandidates) --
         ExpressionSet(splitConjunctivePredicates(nc))).reduceOption(And) match {
         case Some(ac) =>
@@ -1560,7 +1561,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
         case None => nf
       }
-      nonDeterministic.reduceOption(And).map(c => Filter(c, mergedFilter)).getOrElse(mergedFilter)
+      rest.reduceOption(And).map(c => Filter(c, mergedFilter)).getOrElse(mergedFilter)
   }
 }
 
@@ -1730,16 +1731,12 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
       // For each filter, expand the alias and check if the filter can be evaluated using
       // attributes produced by the aggregate operator's child operator.
-      val (candidates, nonDeterministic) =
-        splitConjunctivePredicates(condition).partition(_.deterministic)
-
-      val (pushDown, rest) = candidates.partition { cond =>
+      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
         val replaced = replaceAlias(cond, aliasMap)
-        cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
+        cond.deterministic && !cond.throwable &&
+          cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
       }
 
-      val stayUp = rest ++ nonDeterministic
-
       if (pushDown.nonEmpty) {
         val pushDownPredicate = pushDown.reduce(And)
         val replaced = replaceAlias(pushDownPredicate, aliasMap)
@@ -1904,13 +1901,14 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
    * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth)
    */
   private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
-    val (pushDownCandidates, nonDeterministic) = condition.partition(_.deterministic)
+    val (pushDownCandidates, stayUp) =
+      condition.partition(cond => cond.deterministic && !cond.throwable)
     val (leftEvaluateCondition, rest) =
       pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
     val (rightEvaluateCondition, commonCondition) =
       rest.partition(expr => expr.references.subsetOf(right.outputSet))
-    (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)
+    (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ stayUp)
   }
 
   private def canPushThrough(joinType: JoinType): Boolean = joinType match {
@@ -1933,8 +1931,9 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
           reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
         val newRight = rightFilterConditions.
           reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+        // don't push throwable expressions into join condition
         val (newJoinConditions, others) =
-          commonFilterCondition.partition(canEvaluateWithinJoin)
+          commonFilterCondition.partition(cond => canEvaluateWithinJoin(cond) && !cond.throwable)
         val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
 
         val join = Join(newLeft, newRight, joinType, newJoinCond, hint)
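All three rules now share the same guard: only deterministic, non-throwable conjuncts may move. A minimal self-contained sketch of that partition (the `Pred` type is illustrative, not Catalyst's `Expression` API):

    // Stand-in for an expression carrying the two static properties the
    // optimizer consults before moving a predicate.
    case class Pred(sql: String, deterministic: Boolean, throwable: Boolean)

    val conjuncts = Seq(
      Pred("x.c > 1", deterministic = true, throwable = false),
      Pred("isnotnull(sequence(x.a, x.b, 1))", deterministic = true, throwable = true),
      Pred("rand() < 0.5", deterministic = false, throwable = false))

    val (pushDownCandidates, stayUp) =
      conjuncts.partition(p => p.deterministic && !p.throwable)
    // pushDownCandidates: only "x.c > 1".
    // stayUp: the throwable and the non-deterministic conjuncts are re-assembled
    // via reduceOption(And) into a Filter kept at the original position.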
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 2ebb43d4fba3..bd2ac28a049f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -1433,4 +1433,67 @@ class FilterPushdownSuite extends PlanTest {
     val correctAnswer = RebalancePartitions(Seq.empty, testRelation.where($"a" > 3)).analyze
     comparePlans(optimized, correctAnswer)
   }
+
+  test("SPARK-46707: push down predicate with sequence (without step) through joins") {
+    val x = testRelation.subquery("x")
+    val y = testRelation1.subquery("y")
+
+    // do not push down when sequence has step param
+    val queryWithStep = x.join(y, joinType = Inner, condition = Some($"x.c" === $"y.d"))
+      .where(IsNotNull(Sequence($"x.a", $"x.b", Some(Literal(1)))))
+      .analyze
+    val optimizedQueryWithStep = Optimize.execute(queryWithStep)
+    comparePlans(optimizedQueryWithStep, queryWithStep)
+
+    // push down when sequence does not have step param
+    val queryWithoutStep = x.join(y, joinType = Inner, condition = Some($"x.c" === $"y.d"))
+      .where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+      .analyze
+    val optimizedQueryWithoutStep = Optimize.execute(queryWithoutStep)
+    val correctAnswer = x.where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+      .join(y, joinType = Inner, condition = Some($"x.c" === $"y.d"))
+      .analyze
+    comparePlans(optimizedQueryWithoutStep, correctAnswer)
+  }
+
+  test("SPARK-46707: push down predicate with sequence (without step) through aggregates") {
+    val x = testRelation.subquery("x")
+
+    // do not push down when sequence has step param
+    val queryWithStep = x.groupBy($"x.a", $"x.b")($"x.a", $"x.b")
+      .where(IsNotNull(Sequence($"x.a", $"x.b", Some(Literal(1)))))
+      .analyze
+    val optimizedQueryWithStep = Optimize.execute(queryWithStep)
+    comparePlans(optimizedQueryWithStep, queryWithStep)
+
+    // push down when sequence does not have step param
+    val queryWithoutStep = x.groupBy($"x.a", $"x.b")($"x.a", $"x.b")
+      .where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+      .analyze
+    val optimizedQueryWithoutStep = Optimize.execute(queryWithoutStep)
+    val correctAnswer = x.where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+      .groupBy($"x.a", $"x.b")($"x.a", $"x.b")
+      .analyze
+    comparePlans(optimizedQueryWithoutStep, correctAnswer)
+  }
+
+  test("SPARK-46707: combine predicate with sequence (without step) with other filters") {
+    val x = testRelation.subquery("x")
+
+    // do not combine when sequence has step param
+    val queryWithStep = x.where($"x.c" > 1)
+      .where(IsNotNull(Sequence($"x.a", $"x.b", Some(Literal(1)))))
+      .analyze
+    val optimizedQueryWithStep = Optimize.execute(queryWithStep)
+    comparePlans(optimizedQueryWithStep, queryWithStep)
+
+    // combine when sequence does not have step param
+    val queryWithoutStep = x.where($"x.c" > 1)
+      .where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+      .analyze
+    val optimizedQueryWithoutStep = Optimize.execute(queryWithoutStep)
+    val correctAnswer = x.where(IsNotNull(Sequence($"x.a", $"x.b", None)) && $"x.c" > 1)
+      .analyze
+    comparePlans(optimizedQueryWithoutStep, correctAnswer)
+  }
 }
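To observe the behavior these tests assert outside the test suite, a hedged spark-shell check (the printed plan shapes are paraphrased and version-dependent):

    spark.range(10).toDF("a").createOrReplaceTempView("t1")
    spark.range(10).toDF("b").createOrReplaceTempView("t2")

    // Throwable (explicit step): the Filter should remain above the Join.
    spark.sql(
      """SELECT * FROM t1 JOIN t2 ON t1.a = t2.b
        |WHERE isnotnull(sequence(t1.a, t1.a + 2, CAST(1 AS BIGINT)))""".stripMargin)
      .explain(true)

    // Non-throwable (no step): the Filter should be pushed below the Join,
    // onto the t1 side.
    spark.sql(
      """SELECT * FROM t1 JOIN t2 ON t1.a = t2.b
        |WHERE isnotnull(sequence(t1.a, t1.a + 2))""".stripMargin)
      .explain(true)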
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org