Repository: spark Updated Branches: refs/heads/branch-2.1 77911201c -> fc242ccf4
[SPARK-20246][SQL] should not push predicate down through aggregate with non-deterministic expressions ## What changes were proposed in this pull request? Similar to `Project`, when `Aggregate` has non-deterministic expressions, we should not push a predicate down through it, as doing so changes the number of input rows that the `Aggregate` processes and thus changes the evaluation result of the non-deterministic expressions in `Aggregate`. ## How was this patch tested? new regression test Author: Wenchen Fan <wenc...@databricks.com> Closes #17562 from cloud-fan/filter. (cherry picked from commit 7577e9c356b580d744e1fc27c645fce41bdf9cf0) Signed-off-by: Xiao Li <gatorsm...@gmail.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fc242ccf Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fc242ccf Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fc242ccf Branch: refs/heads/branch-2.1 Commit: fc242ccf4b8c8e3e3932a814281ebeb14302f0d2 Parents: 7791120 Author: Wenchen Fan <wenc...@databricks.com> Authored: Fri Apr 7 20:54:18 2017 -0700 Committer: Xiao Li <gatorsm...@gmail.com> Committed: Fri Apr 7 20:54:31 2017 -0700 ---------------------------------------------------------------------- .../sql/catalyst/optimizer/Optimizer.scala | 60 ++++++++++---------- .../optimizer/FilterPushdownSuite.scala | 41 +++++++++++-- 2 files changed, 68 insertions(+), 33 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/fc242ccf/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1ca4dba..291a0c8 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -742,7 +742,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // implies that, for a given input row, the output are determined by the expression's initial // state and all the input rows processed before. In another word, the order of input rows // matters for non-deterministic expressions, while pushing down predicates changes the order. - case filter @ Filter(condition, project @ Project(fields, grandChild)) + // This also applies to Aggregate. + case Filter(condition, project @ Project(fields, grandChild)) if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) => // Create a map of Aliases to their values from the child projection. @@ -753,33 +754,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) - // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be - // pushed beneath must satisfy the following conditions: - // 1. All the expressions are part of window partitioning key. The expressions can be compound. - // 2. Deterministic. - // 3. Placed before any non-deterministic predicates. 
- case filter @ Filter(condition, w: Window) - if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => - val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) - - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) - - val (pushDown, rest) = candidates.partition { cond => - cond.references.subsetOf(partitionAttrs) - } - - val stayUp = rest ++ containingNonDeterministic - - if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val newWindow = w.copy(child = Filter(pushDownPredicate, w.child)) - if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow) - } else { - filter - } - - case filter @ Filter(condition, aggregate: Aggregate) => + case filter @ Filter(condition, aggregate: Aggregate) + if aggregate.aggregateExpressions.forall(_.deterministic) => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { @@ -810,6 +786,32 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } + // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be + // pushed beneath must satisfy the following conditions: + // 1. All the expressions are part of window partitioning key. The expressions can be compound. + // 2. Deterministic. + // 3. Placed before any non-deterministic predicates. 
+ case filter @ Filter(condition, w: Window) + if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => + val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) + + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(partitionAttrs) + } + + val stayUp = rest ++ containingNonDeterministic + + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val newWindow = w.copy(child = Filter(pushDownPredicate, w.child)) + if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow) + } else { + filter + } + case filter @ Filter(condition, union: Union) => // Union could change the rows, so non-deterministic predicate can't be pushed down val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic) @@ -835,7 +837,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } - case filter @ Filter(condition, u: UnaryNode) + case filter @ Filter(_, u: UnaryNode) if canPushThrough(u) && u.expressions.forall(_.deterministic) => pushDownPredicate(filter, u.child) { predicate => u.withNewChildren(Seq(Filter(predicate, u.child))) http://git-wip-us.apache.org/repos/asf/spark/blob/fc242ccf/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 6feea40..150ebd2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -134,15 +134,20 @@ class FilterPushdownSuite extends PlanTest 
{ comparePlans(optimized, correctAnswer) } - test("nondeterministic: can't push down filter with nondeterministic condition through project") { + test("nondeterministic: can always push down filter through project with deterministic field") { val originalQuery = testRelation - .select(Rand(10).as('rand), 'a) - .where('rand > 5 || 'a > 5) + .select('a) + .where(Rand(10) > 5 || 'a > 5) .analyze val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, originalQuery) + val correctAnswer = testRelation + .where(Rand(10) > 5 || 'a > 5) + .select('a) + .analyze + + comparePlans(optimized, correctAnswer) } test("nondeterministic: can't push down filter through project with nondeterministic field") { @@ -156,6 +161,34 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } + test("nondeterministic: can't push down filter through aggregate with nondeterministic field") { + val originalQuery = testRelation + .groupBy('a)('a, Rand(10).as('rand)) + .where('a > 5) + .analyze + + val optimized = Optimize.execute(originalQuery) + + comparePlans(optimized, originalQuery) + } + + test("nondeterministic: push down part of filter through aggregate with deterministic field") { + val originalQuery = testRelation + .groupBy('a)('a) + .where('a > 5 && Rand(10) > 5) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .where('a > 5) + .groupBy('a)('a) + .where(Rand(10) > 5) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("filters: combines filters") { val originalQuery = testRelation .select('a) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org