Repository: spark
Updated Branches:
  refs/heads/master f4772fd26 -> 5f3441e54


[SPARK-24893][SQL] Remove the entire CaseWhen if all the outputs are semantically equivalent

## What changes were proposed in this pull request?

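Similar to SPARK-24890, if all the outputs of `CaseWhen` (including the `else` value) are
semantically equivalent, the `CaseWhen` can be replaced by that shared value. Conditions that
are non-deterministic may have side effects, so they cannot be removed or reordered. In that
case only the deterministic conditions at the tail are dropped, and every remaining branch
outputs the `else` value.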

## How was this patch tested?

Tests added to `SimplifyConditionalSuite`.

Author: DB Tsai <d_t...@apple.com>

Closes #21852 from dbtsai/short-circuit-when.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5f3441e5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5f3441e5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5f3441e5

Branch: refs/heads/master
Commit: 5f3441e542bfacd81d70bd8b34c22044c8928bff
Parents: f4772fd
Author: DB Tsai <d_t...@apple.com>
Authored: Wed Aug 1 10:31:02 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Wed Aug 1 10:31:02 2018 +0800

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/expressions.scala    | 18 ++++++++
 .../optimizer/SimplifyConditionalSuite.scala    | 48 +++++++++++++++++++-
 2 files changed, 64 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5f3441e5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 4696699..e7b4730 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -416,6 +416,24 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
         // these branches can be pruned away
         val (h, t) = branches.span(_._1 != TrueLiteral)
         CaseWhen( h :+ t.head, None)
+
+      case e @ CaseWhen(branches, Some(elseValue))
+          if branches.forall(_._2.semanticEquals(elseValue)) =>
+        // For non-deterministic conditions with side effect, we can not 
remove it, or change
+        // the ordering. As a result, we try to remove the deterministic 
conditions from the tail.
+        var hitNonDeterministicCond = false
+        var i = branches.length
+        while (i > 0 && !hitNonDeterministicCond) {
+          hitNonDeterministicCond = !branches(i - 1)._1.deterministic
+          if (!hitNonDeterministicCond) {
+            i -= 1
+          }
+        }
+        if (i == 0) {
+          elseValue
+        } else {
+          e.copy(branches = branches.take(i).map(branch => (branch._1, 
elseValue)))
+        }
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5f3441e5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index e210874..8ad7c12 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, 
TrueLiteral}
@@ -46,7 +45,9 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
   private val unreachableBranch = (FalseLiteral, Literal(20))
   private val nullBranch = (Literal.create(null, NullType), Literal(30))
 
-  private val testRelation = LocalRelation('a.int)
+  val isNotNullCond = IsNotNull(UnresolvedAttribute(Seq("a")))
+  val isNullCond = IsNull(UnresolvedAttribute("b"))
+  val notCond = Not(UnresolvedAttribute("c"))
 
   test("simplify if") {
     assertEquivalent(
@@ -122,4 +123,47 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
         None),
       CaseWhen(normalBranch :: trueBranch :: Nil, None))
   }
+
+  test("simplify CaseWhen if all the outputs are semantic equivalence") {
+    // When the conditions in `CaseWhen` are all deterministic, `CaseWhen` can 
be removed.
+    assertEquivalent(
+      CaseWhen((isNotNullCond, Subtract(Literal(3), Literal(2))) ::
+        (isNullCond, Literal(1)) ::
+        (notCond, Add(Literal(6), Literal(-5))) ::
+        Nil,
+        Add(Literal(2), Literal(-1))),
+      Literal(1)
+    )
+
+    // For non-deterministic conditions, we don't remove the `CaseWhen` 
statement.
+    assertEquivalent(
+      CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Subtract(Literal(3), 
Literal(2))) ::
+        (LessThan(Rand(1), Literal(0.5)), Literal(1)) ::
+        (EqualTo(Rand(2), Literal(0.5)), Add(Literal(6), Literal(-5))) ::
+        Nil,
+        Add(Literal(2), Literal(-1))),
+      CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) ::
+        (LessThan(Rand(1), Literal(0.5)), Literal(1)) ::
+        (EqualTo(Rand(2), Literal(0.5)), Literal(1)) ::
+        Nil,
+        Literal(1))
+    )
+
+    // When we have mixture of deterministic and non-deterministic conditions, 
we remove
+    // the deterministic conditions from the tail until a non-deterministic 
one is seen.
+    assertEquivalent(
+      CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Subtract(Literal(3), 
Literal(2))) ::
+        (NonFoldableLiteral(true), Add(Literal(2), Literal(-1))) ::
+        (LessThan(Rand(1), Literal(0.5)), Literal(1)) ::
+        (NonFoldableLiteral(true), Add(Literal(6), Literal(-5))) ::
+        (NonFoldableLiteral(false), Literal(1)) ::
+        Nil,
+        Add(Literal(2), Literal(-1))),
+      CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) ::
+        (NonFoldableLiteral(true), Literal(1)) ::
+        (LessThan(Rand(1), Literal(0.5)), Literal(1)) ::
+        Nil,
+        Literal(1))
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to