Repository: spark Updated Branches: refs/heads/master f4772fd26 -> 5f3441e54
[SPARK-24893][SQL] Remove the entire CaseWhen if all the outputs are semantic equivalence ## What changes were proposed in this pull request? Similar to SPARK-24890, if all the outputs of `CaseWhen` are semantic equivalence, `CaseWhen` can be removed. ## How was this patch tested? Tests added. Author: DB Tsai <d_t...@apple.com> Closes #21852 from dbtsai/short-circuit-when. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5f3441e5 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5f3441e5 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5f3441e5 Branch: refs/heads/master Commit: 5f3441e542bfacd81d70bd8b34c22044c8928bff Parents: f4772fd Author: DB Tsai <d_t...@apple.com> Authored: Wed Aug 1 10:31:02 2018 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Wed Aug 1 10:31:02 2018 +0800 ---------------------------------------------------------------------- .../sql/catalyst/optimizer/expressions.scala | 18 ++++++++ .../optimizer/SimplifyConditionalSuite.scala | 48 +++++++++++++++++++- 2 files changed, 64 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/5f3441e5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 4696699..e7b4730 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -416,6 +416,24 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { // these branches can be pruned away val (h, t) = branches.span(_._1 != 
TrueLiteral) CaseWhen( h :+ t.head, None) + + case e @ CaseWhen(branches, Some(elseValue)) + if branches.forall(_._2.semanticEquals(elseValue)) => + // For non-deterministic conditions with side effects, we cannot remove them or change + // their ordering. As a result, we try to remove the deterministic conditions from the tail. + var hitNonDeterministicCond = false + var i = branches.length + while (i > 0 && !hitNonDeterministicCond) { + hitNonDeterministicCond = !branches(i - 1)._1.deterministic + if (!hitNonDeterministicCond) { + i -= 1 + } + } + if (i == 0) { + elseValue + } else { + e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) + } } } } http://git-wip-us.apache.org/repos/asf/spark/blob/5f3441e5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index e210874..8ad7c12 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} @@ -46,7 +45,9 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { private val unreachableBranch = (FalseLiteral, Literal(20)) private val nullBranch = (Literal.create(null, NullType), Literal(30)) - private val testRelation = 
LocalRelation('a.int) + val isNotNullCond = IsNotNull(UnresolvedAttribute(Seq("a"))) + val isNullCond = IsNull(UnresolvedAttribute("b")) + val notCond = Not(UnresolvedAttribute("c")) test("simplify if") { assertEquivalent( @@ -122,4 +123,47 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { None), CaseWhen(normalBranch :: trueBranch :: Nil, None)) } + + test("simplify CaseWhen if all the outputs are semantic equivalence") { + // When the conditions in `CaseWhen` are all deterministic, `CaseWhen` can be removed. + assertEquivalent( + CaseWhen((isNotNullCond, Subtract(Literal(3), Literal(2))) :: + (isNullCond, Literal(1)) :: + (notCond, Add(Literal(6), Literal(-5))) :: + Nil, + Add(Literal(2), Literal(-1))), + Literal(1) + ) + + // For non-deterministic conditions, we don't remove the `CaseWhen` statement. + assertEquivalent( + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Subtract(Literal(3), Literal(2))) :: + (LessThan(Rand(1), Literal(0.5)), Literal(1)) :: + (EqualTo(Rand(2), Literal(0.5)), Add(Literal(6), Literal(-5))) :: + Nil, + Add(Literal(2), Literal(-1))), + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(1)) :: + (EqualTo(Rand(2), Literal(0.5)), Literal(1)) :: + Nil, + Literal(1)) + ) + + // When we have a mixture of deterministic and non-deterministic conditions, we remove + // the deterministic conditions from the tail until a non-deterministic one is seen.
+ assertEquivalent( + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Subtract(Literal(3), Literal(2))) :: + (NonFoldableLiteral(true), Add(Literal(2), Literal(-1))) :: + (LessThan(Rand(1), Literal(0.5)), Literal(1)) :: + (NonFoldableLiteral(true), Add(Literal(6), Literal(-5))) :: + (NonFoldableLiteral(false), Literal(1)) :: + Nil, + Add(Literal(2), Literal(-1))), + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) :: + (NonFoldableLiteral(true), Literal(1)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(1)) :: + Nil, + Literal(1)) + ) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org