Repository: spark
Updated Branches:
  refs/heads/master a35a67a83 -> 6e6320122
[SPARK-14830][SQL] Add RemoveRepetitionFromGroupExpressions optimizer.

## What changes were proposed in this pull request?

This PR aims to optimize GroupExpressions by removing repeating expressions. `RemoveRepetitionFromGroupExpressions` is added.

**Before**
```scala
scala> sql("select a+1 from values 1,2 T(a) group by a+1, 1+a, A+1, 1+A").explain()
== Physical Plan ==
WholeStageCodegen
:  +- TungstenAggregate(key=[(a#0 + 1)#6,(1 + a#0)#7,(A#0 + 1)#8,(1 + A#0)#9], functions=[], output=[(a + 1)#5])
:     +- INPUT
+- Exchange hashpartitioning((a#0 + 1)#6, (1 + a#0)#7, (A#0 + 1)#8, (1 + A#0)#9, 200), None
   +- WholeStageCodegen
      :  +- TungstenAggregate(key=[(a#0 + 1) AS (a#0 + 1)#6,(1 + a#0) AS (1 + a#0)#7,(A#0 + 1) AS (A#0 + 1)#8,(1 + A#0) AS (1 + A#0)#9], functions=[], output=[(a#0 + 1)#6,(1 + a#0)#7,(A#0 + 1)#8,(1 + A#0)#9])
      :     +- INPUT
      +- LocalTableScan [a#0], [[1],[2]]
```

**After**
```scala
scala> sql("select a+1 from values 1,2 T(a) group by a+1, 1+a, A+1, 1+A").explain()
== Physical Plan ==
WholeStageCodegen
:  +- TungstenAggregate(key=[(a#0 + 1)#6], functions=[], output=[(a + 1)#5])
:     +- INPUT
+- Exchange hashpartitioning((a#0 + 1)#6, 200), None
   +- WholeStageCodegen
      :  +- TungstenAggregate(key=[(a#0 + 1) AS (a#0 + 1)#6], functions=[], output=[(a#0 + 1)#6])
      :     +- INPUT
      +- LocalTableScan [a#0], [[1],[2]]
```

## How was this patch tested?

Pass the Jenkins tests (with a new testcase)

Author: Dongjoon Hyun <dongj...@apache.org>

Closes #12590 from dongjoon-hyun/SPARK-14830.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6e632012 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6e632012 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6e632012 Branch: refs/heads/master Commit: 6e6320122ea84247c67e2d0fb0e6af54e2c5bb31 Parents: a35a67a Author: Dongjoon Hyun <dongj...@apache.org> Authored: Mon May 2 12:40:21 2016 -0700 Committer: Michael Armbrust <mich...@databricks.com> Committed: Mon May 2 12:40:21 2016 -0700 ---------------------------------------------------------------------- .../sql/catalyst/optimizer/Optimizer.scala | 15 ++++++++++++++- .../optimizer/AggregateOptimizeSuite.scala | 20 +++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/6e632012/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0b70ede..a147fff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -68,7 +68,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) ReplaceExceptWithAntiJoin, ReplaceDistinctWithAggregate) :: Batch("Aggregate", fixedPoint, - RemoveLiteralFromGroupExpressions) :: + RemoveLiteralFromGroupExpressions, + RemoveRepetitionFromGroupExpressions) :: Batch("Operator Optimizations", fixedPoint, // Operator push down SetOperationPushDown, @@ -1440,6 +1441,18 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { } /** + * Removes repetition from group expressions 
in [[Aggregate]], as they have no effect to the result + * but only makes the grouping key bigger. + */ +object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(grouping, _, _) => + val newGrouping = ExpressionSet(grouping).toSeq + a.copy(groupingExpressions = newGrouping) + } +} + +/** * Computes the current date and time to make sure we return the same result in a single query. */ object ComputeCurrentTime extends Rule[LogicalPlan] { http://git-wip-us.apache.org/repos/asf/spark/blob/6e632012/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index e458eb8..c94dcb3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal @@ -25,10 +28,14 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor class AggregateOptimizeSuite extends PlanTest { + val conf = new SimpleCatalystConf(caseSensitiveAnalysis = false) + val catalog = new SessionCatalog(new InMemoryCatalog, 
EmptyFunctionRegistry, conf) + val analyzer = new Analyzer(catalog, conf) object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), - RemoveLiteralFromGroupExpressions) :: Nil + RemoveLiteralFromGroupExpressions, + RemoveRepetitionFromGroupExpressions) :: Nil } test("remove literals in grouping expression") { @@ -42,4 +49,15 @@ class AggregateOptimizeSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("remove repetition in grouping expression") { + val input = LocalRelation('a.int, 'b.int, 'c.int) + + val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) + val optimized = Optimize.execute(analyzer.execute(query)) + + val correctAnswer = analyzer.execute(input.groupBy('a + 1, 'b + 2)(sum('c))) + + comparePlans(optimized, correctAnswer) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org