This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 62956c92cfc7 [SPARK-46707][SQL] Added throwable field to expressions to improve predicate pushdown
62956c92cfc7 is described below

commit 62956c92cfc74d7523328d168b6d837938cde763
Author: Kelvin Jiang <kelvin.ji...@databricks.com>
AuthorDate: Thu Jan 18 19:25:24 2024 +0800

    [SPARK-46707][SQL] Added throwable field to expressions to improve predicate pushdown
    
    ### What changes were proposed in this pull request?
    
    This PR adds the field `throwable` to `Expression`. If an expression is marked as throwable, we will avoid pushing filters containing such expressions through joins, filters, and aggregations (i.e. operators that filter their input).
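    
    As a minimal sketch of the new contract (a hypothetical `spark-shell` snippet against a build that includes this change): `Sequence` overrides the field when an explicit step is given, and the default implementation propagates throwability from children.
    
    ```scala
    import org.apache.spark.sql.catalyst.expressions.{IsNotNull, Literal, Sequence}
    
    // Sequence overrides `throwable` as stepOpt.isDefined: evaluation can throw
    // when start/stop do not satisfy the explicit step.
    val seqWithStep = Sequence(Literal(1), Literal(5), Some(Literal(1)))
    assert(seqWithStep.throwable)
    
    // The default is children.exists(_.throwable), so throwability propagates up
    // through enclosing expressions such as IsNotNull.
    assert(IsNotNull(seqWithStep).throwable)
    ```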
    
    ### Why are the changes needed?
    
    Currently, predicate pushdown can result in a filter being evaluated on more rows than before it was pushed down (e.g. when the filter is pushed through a selective join). In that case, the filter may be evaluated on a row that causes a runtime error to be thrown, which would not have happened prior to the pushdown.
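    
    For example, in the hypothetical `spark-shell` snippet below (names and data are illustrative), the join is selective enough that the filter never sees the row `(5, 1)`. If the filter were pushed below the join, `sequence(5, 1, 1)` would be evaluated on that row and fail at runtime, because the step does not move from start toward stop.
    
    ```scala
    import spark.implicits._
    import org.apache.spark.sql.functions.expr
    
    val x = Seq((1, 3), (5, 1)).toDF("a", "b")
    val y = Seq(1).toDF("c")
    
    x.join(y, $"a" === $"c")                       // keeps only the (1, 3) row
      .where(expr("sequence(a, b, 1)").isNotNull)  // safe after the join,
      .show()                                      // but not below it
    ```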
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added UTs.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44716 from kelvinjian-db/SPARK-46707-throwable.
    
    Authored-by: Kelvin Jiang <kelvin.ji...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/expressions/Expression.scala      |  5 ++
 .../expressions/collectionOperations.scala         |  3 ++
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 27 +++++-----
 .../catalyst/optimizer/FilterPushdownSuite.scala   | 63 ++++++++++++++++++++++
 4 files changed, 84 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 2cc813bd3055..484418f5e5a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -140,6 +140,11 @@ abstract class Expression extends TreeNode[Expression] {
    */
   def stateful: Boolean = false
 
+  /**
+   * Returns true if the expression could potentially throw an exception when evaluated.
+   */
+  lazy val throwable: Boolean = children.exists(_.throwable)
+
   /**
    * Returns a copy of this expression where all stateful expressions are replaced with fresh
    * uninitialized copies. If the expression contains no stateful expressions then the original
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 04f56eaf8c1e..5aa96dd1a6aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2983,6 +2983,9 @@ case class Sequence(
 
   override def nullable: Boolean = children.exists(_.nullable)
 
+  // If step is defined, then an error will be thrown if the start and stop do not satisfy the step.
+  override lazy val throwable: Boolean = stepOpt.isDefined
+
   override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false)
 
   override def checkInputDataTypes(): TypeCheckResult = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 8fcc7c7c26b4..4186c8c1db91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1549,10 +1549,11 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
 
   val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
     // The query execution/optimization does not guarantee the expressions are evaluated in order.
-    // We only can combine them if and only if both are deterministic.
+    // We only can combine them if and only if both are deterministic and the outer condition is not
+    // throwable (inner can be throwable as it was going to be evaluated first anyways).
     case Filter(fc, nf @ Filter(nc, grandChild)) if nc.deterministic =>
-      val (combineCandidates, nonDeterministic) =
-        splitConjunctivePredicates(fc).partition(_.deterministic)
+      val (combineCandidates, rest) =
+        splitConjunctivePredicates(fc).partition(p => p.deterministic && !p.throwable)
       val mergedFilter = (ExpressionSet(combineCandidates) --
         ExpressionSet(splitConjunctivePredicates(nc))).reduceOption(And) match {
         case Some(ac) =>
@@ -1560,7 +1561,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
         case None =>
           nf
       }
-      nonDeterministic.reduceOption(And).map(c => Filter(c, mergedFilter)).getOrElse(mergedFilter)
+      rest.reduceOption(And).map(c => Filter(c, mergedFilter)).getOrElse(mergedFilter)
   }
 }
 
@@ -1730,16 +1731,12 @@ object PushPredicateThroughNonJoin extends Rule[LogicalPlan] with PredicateHelpe
 
       // For each filter, expand the alias and check if the filter can be evaluated using
       // attributes produced by the aggregate operator's child operator.
-      val (candidates, nonDeterministic) =
-        splitConjunctivePredicates(condition).partition(_.deterministic)
-
-      val (pushDown, rest) = candidates.partition { cond =>
+      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
         val replaced = replaceAlias(cond, aliasMap)
-        cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
+        cond.deterministic && !cond.throwable &&
+          cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
       }
 
-      val stayUp = rest ++ nonDeterministic
-
       if (pushDown.nonEmpty) {
         val pushDownPredicate = pushDown.reduce(And)
         val replaced = replaceAlias(pushDownPredicate, aliasMap)
@@ -1904,13 +1901,14 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
    * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth)
    */
   private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
-    val (pushDownCandidates, nonDeterministic) = condition.partition(_.deterministic)
+    val (pushDownCandidates, stayUp) =
+      condition.partition(cond => cond.deterministic && !cond.throwable)
     val (leftEvaluateCondition, rest) =
       pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
     val (rightEvaluateCondition, commonCondition) =
         rest.partition(expr => expr.references.subsetOf(right.outputSet))
 
-    (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)
+    (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ stayUp)
   }
 
   private def canPushThrough(joinType: JoinType): Boolean = joinType match {
@@ -1933,8 +1931,9 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
             reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
           val newRight = rightFilterConditions.
             reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+          // don't push throwable expressions into join condition
           val (newJoinConditions, others) =
-            commonFilterCondition.partition(canEvaluateWithinJoin)
+            commonFilterCondition.partition(cond => canEvaluateWithinJoin(cond) && !cond.throwable)
           val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
 
           val join = Join(newLeft, newRight, joinType, newJoinCond, hint)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 2ebb43d4fba3..bd2ac28a049f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -1433,4 +1433,67 @@ class FilterPushdownSuite extends PlanTest {
     val correctAnswer = RebalancePartitions(Seq.empty, testRelation.where($"a" > 3)).analyze
     comparePlans(optimized, correctAnswer)
   }
+
+  test("SPARK-46707: push down predicate with sequence (without step) through 
joins") {
+    val x = testRelation.subquery("x")
+    val y = testRelation1.subquery("y")
+
+    // do not push down when sequence has step param
+    val queryWithStep = x.join(y, joinType = Inner, condition = Some($"x.c" === $"y.d"))
+      .where(IsNotNull(Sequence($"x.a", $"x.b", Some(Literal(1)))))
+      .analyze
+    val optimizedQueryWithStep = Optimize.execute(queryWithStep)
+    comparePlans(optimizedQueryWithStep, queryWithStep)
+
+    // push down when sequence does not have step param
+    val queryWithoutStep = x.join(y, joinType = Inner, condition = Some($"x.c" === $"y.d"))
+      .where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+      .analyze
+    val optimizedQueryWithoutStep = Optimize.execute(queryWithoutStep)
+    val correctAnswer = x.where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+      .join(y, joinType = Inner, condition = Some($"x.c" === $"y.d"))
+      .analyze
+    comparePlans(optimizedQueryWithoutStep, correctAnswer)
+  }
+
+  test("SPARK-46707: push down predicate with sequence (without step) through 
aggregates") {
+    val x = testRelation.subquery("x")
+
+    // do not push down when sequence has step param
+    val queryWithStep = x.groupBy($"x.a", $"x.b")($"x.a", $"x.b")
+      .where(IsNotNull(Sequence($"x.a", $"x.b", Some(Literal(1)))))
+      .analyze
+    val optimizedQueryWithStep = Optimize.execute(queryWithStep)
+    comparePlans(optimizedQueryWithStep, queryWithStep)
+
+    // push down when sequence does not have step param
+    val queryWithoutStep = x.groupBy($"x.a", $"x.b")($"x.a", $"x.b")
+      .where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+      .analyze
+    val optimizedQueryWithoutStep = Optimize.execute(queryWithoutStep)
+    val correctAnswer = x.where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+      .groupBy($"x.a", $"x.b")($"x.a", $"x.b")
+      .analyze
+    comparePlans(optimizedQueryWithoutStep, correctAnswer)
+  }
+
+  test("SPARK-46707: combine predicate with sequence (without step) with other 
filters") {
+    val x = testRelation.subquery("x")
+
+    // do not combine when sequence has step param
+    val queryWithStep = x.where($"x.c" > 1)
+      .where(IsNotNull(Sequence($"x.a", $"x.b", Some(Literal(1)))))
+      .analyze
+    val optimizedQueryWithStep = Optimize.execute(queryWithStep)
+    comparePlans(optimizedQueryWithStep, queryWithStep)
+
+    // combine when sequence does not have step param
+    val queryWithoutStep = x.where($"x.c" > 1)
+      .where(IsNotNull(Sequence($"x.a", $"x.b", None)))
+      .analyze
+    val optimizedQueryWithoutStep = Optimize.execute(queryWithoutStep)
+    val correctAnswer = x.where(IsNotNull(Sequence($"x.a", $"x.b", None)) && $"x.c" > 1)
+      .analyze
+    comparePlans(optimizedQueryWithoutStep, correctAnswer)
+  }
 }

