Repository: spark
Updated Branches:
  refs/heads/master c26b09216 -> d4c341589


[SPARK-24890][SQL] Short circuiting the `if` condition when `trueValue` and 
`falseValue` are the same

## What changes were proposed in this pull request?

When `trueValue` and `falseValue` are semantically equivalent and the condition 
is deterministic, the condition expression in `If` can be removed to avoid extra 
computation at runtime.

## How was this patch tested?

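Added unit tests to `SimplifyConditionalSuite`. They cover a deterministic condition, 
where the `If` is removed, and a non-deterministic condition (`Rand`), where the `If` 
is kept.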

Author: DB Tsai <d_t...@apple.com>

Closes #21848 from dbtsai/short-circuit-if.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d4c34158
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d4c34158
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d4c34158

Branch: refs/heads/master
Commit: d4c341589499099654ed4febf235f19897a21601
Parents: c26b092
Author: DB Tsai <d_t...@apple.com>
Authored: Tue Jul 24 20:21:11 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Tue Jul 24 20:21:11 2018 -0700

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/expressions.scala    |  7 ++++--
 .../optimizer/SimplifyConditionalSuite.scala    | 24 +++++++++++++++++++-
 .../apache/spark/sql/test/SQLTestUtils.scala    |  2 +-
 3 files changed, 29 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d4c34158/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index cf17f59..4696699 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -390,6 +390,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with 
PredicateHelper {
       case If(TrueLiteral, trueValue, _) => trueValue
       case If(FalseLiteral, _, falseValue) => falseValue
       case If(Literal(null, _), _, falseValue) => falseValue
+      case If(cond, trueValue, falseValue)
+        if cond.deterministic && trueValue.semanticEquals(falseValue) => 
trueValue
 
       case e @ CaseWhen(branches, elseValue) if branches.exists(x => 
falseOrNullLiteral(x._1)) =>
         // If there are branches that are always false, remove them.
@@ -403,14 +405,14 @@ object SimplifyConditionals extends Rule[LogicalPlan] 
with PredicateHelper {
           e.copy(branches = newBranches)
         }
 
-      case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == 
Some(TrueLiteral) =>
+      case CaseWhen(branches, _) if 
branches.headOption.map(_._1).contains(TrueLiteral) =>
         // If the first branch is a true literal, remove the entire CaseWhen 
and use the value
         // from that. Note that CaseWhen.branches should never be empty, and 
as a result the
         // headOption (rather than head) added above is just an extra (and 
unnecessary) safeguard.
         branches.head._2
 
       case CaseWhen(branches, _) if branches.exists(_._1 == TrueLiteral) =>
-        // a branc with a TRue condition eliminates all following branches,
+        // a branch with a true condition eliminates all following branches,
         // these branches can be pruned away
         val (h, t) = branches.span(_._1 != TrueLiteral)
         CaseWhen( h :+ t.head, None)
@@ -651,6 +653,7 @@ object SimplifyCaseConversionExpressions extends 
Rule[LogicalPlan] {
   }
 }
 
+
 /**
  * Combine nested [[Concat]] expressions.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/d4c34158/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index b597c8e..e210874 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, 
TrueLiteral}
@@ -29,7 +31,8 @@ import org.apache.spark.sql.types.{IntegerType, NullType}
 class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("SimplifyConditionals", FixedPoint(50), 
SimplifyConditionals) :: Nil
+    val batches = Batch("SimplifyConditionals", FixedPoint(50),
+      BooleanSimplification, ConstantFolding, SimplifyConditionals) :: Nil
   }
 
   protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
@@ -43,6 +46,8 @@ class SimplifyConditionalSuite extends PlanTest with 
PredicateHelper {
   private val unreachableBranch = (FalseLiteral, Literal(20))
   private val nullBranch = (Literal.create(null, NullType), Literal(30))
 
+  private val testRelation = LocalRelation('a.int)
+
   test("simplify if") {
     assertEquivalent(
       If(TrueLiteral, Literal(10), Literal(20)),
@@ -57,6 +62,23 @@ class SimplifyConditionalSuite extends PlanTest with 
PredicateHelper {
       Literal(20))
   }
 
+  test("remove unnecessary if when the outputs are semantic equivalence") {
+    assertEquivalent(
+      If(IsNotNull(UnresolvedAttribute("a")),
+        Subtract(Literal(10), Literal(1)),
+        Add(Literal(6), Literal(3))),
+      Literal(9))
+
+    // For non-deterministic condition, we don't remove the `If` statement.
+    assertEquivalent(
+      If(GreaterThan(Rand(0), Literal(0.5)),
+        Subtract(Literal(10), Literal(1)),
+        Add(Literal(6), Literal(3))),
+      If(GreaterThan(Rand(0), Literal(0.5)),
+        Literal(9),
+        Literal(9)))
+  }
+
   test("remove unreachable branches") {
     // i.e. removing branches whose conditions are always false
     assertEquivalent(

http://git-wip-us.apache.org/repos/asf/spark/blob/d4c34158/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index e562be8..ac70488 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -393,7 +393,7 @@ private[sql] trait SQLTestUtilsBase
   }
 
   /**
-   * Returns full path to the given file in the resouce folder
+   * Returns full path to the given file in the resource folder
    */
   protected def testFile(fileName: String): String = {
     Thread.currentThread().getContextClassLoader.getResource(fileName).toString


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to