Repository: spark Updated Branches: refs/heads/master 111c05538 -> 3f6d28a5c
[SPARK-9102] [SQL] Improve project collapse with nondeterministic expressions Currently we stop collapsing projects whenever the lower projection has nondeterministic expressions. However, this is sometimes overkill: when the upper projection does not reference those nondeterministic expressions, we should still be able to optimize `df.select(Rand(10)).select('a)` to `df.select('a)`. Author: Wenchen Fan <cloud0...@outlook.com> Closes #7445 from cloud-fan/non-deterministic and squashes the following commits: 0deaef6 [Wenchen Fan] Improve project collapse with nondeterministic expressions Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3f6d28a5 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3f6d28a5 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3f6d28a5 Branch: refs/heads/master Commit: 3f6d28a5ca98cf7d20c2c029094350cc4f9545a0 Parents: 111c055 Author: Wenchen Fan <cloud0...@outlook.com> Authored: Fri Jul 17 00:59:15 2015 -0700 Committer: Yin Huai <yh...@databricks.com> Committed: Fri Jul 17 00:59:15 2015 -0700 ---------------------------------------------------------------------- .../sql/catalyst/optimizer/Optimizer.scala | 38 ++++++++++---------- .../optimizer/ProjectCollapsingSuite.scala | 26 ++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 10 +++--- 3 files changed, 51 insertions(+), 23 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/3f6d28a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2f94b45..d5beeec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -206,31 +206,33 @@ object ColumnPruning extends Rule[LogicalPlan] { */ object ProjectCollapsing extends Rule[LogicalPlan] { - /** Returns true if any expression in projectList is non-deterministic. */ - private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = { - projectList.exists(expr => expr.find(!_.deterministic).isDefined) - } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - // We only collapse these two Projects if the child Project's expressions are all - // deterministic. - case Project(projectList1, Project(projectList2, child)) - if !hasNondeterministic(projectList2) => + case p @ Project(projectList1, Project(projectList2, child)) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). val aliasMap = AttributeMap(projectList2.collect { - case a @ Alias(e, _) => (a.toAttribute, a) + case a: Alias => (a.toAttribute, a) }) - // Substitute any attributes that are produced by the child projection, so that we safely - // eliminate it. - // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' - // TODO: Fix TransformBase to avoid the cast below. - val substitutedProjection = projectList1.map(_.transform { - case a: Attribute if aliasMap.contains(a) => aliasMap(a) - }).asInstanceOf[Seq[NamedExpression]] + // We only collapse these two Projects if their overlapped expressions are all + // deterministic. + val hasNondeterministic = projectList1.flatMap(_.collect { + case a: Attribute if aliasMap.contains(a) => aliasMap(a).child + }).exists(_.find(!_.deterministic).isDefined) - Project(substitutedProjection, child) + if (hasNondeterministic) { + p + } else { + // Substitute any attributes that are produced by the child projection, so that we safely + // eliminate it. 
+ // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' + // TODO: Fix TransformBase to avoid the cast below. + val substitutedProjection = projectList1.map(_.transform { + case a: Attribute => aliasMap.getOrElse(a, a) + }).asInstanceOf[Seq[NamedExpression]] + + Project(substitutedProjection, child) + } } } http://git-wip-us.apache.org/repos/asf/spark/blob/3f6d28a5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala index 151654b..1aa8999 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala @@ -70,4 +70,30 @@ class ProjectCollapsingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("collapse two nondeterministic, independent projects into one") { + val query = testRelation + .select(Rand(10).as('rand)) + .select(Rand(20).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation + .select(Rand(20).as('rand2)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse one nondeterministic, one deterministic, independent projects into one") { + val query = testRelation + .select(Rand(10).as('rand), 'a) + .select(('a + 1).as('a_plus_1)) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation + .select(('a + 1).as('a_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/3f6d28a5/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 23244fd..192cc0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -745,8 +745,8 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("SPARK-8072: Better Exception for Duplicate Columns") { // only one duplicate column present val e = intercept[org.apache.spark.sql.AnalysisException] { - val df1 = Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1") - .write.format("parquet").save("temp") + Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1") + .write.format("parquet").save("temp") } assert(e.getMessage.contains("Duplicate column(s)")) assert(e.getMessage.contains("parquet")) @@ -755,9 +755,9 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { // multiple duplicate columns present val f = intercept[org.apache.spark.sql.AnalysisException] { - val df2 = Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7)) - .toDF("column1", "column2", "column3", "column1", "column3") - .write.format("json").save("temp") + Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7)) + .toDF("column1", "column2", "column3", "column1", "column3") + .write.format("json").save("temp") } assert(f.getMessage.contains("Duplicate column(s)")) assert(f.getMessage.contains("JSON")) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org