Repository: spark
Updated Branches:
  refs/heads/master 700312e12 -> bc0d76a24


[SQL] Simplifies binary node pattern matching

This PR is a simpler version of #2764, and adds `unapply` methods to the 
following binary nodes for simpler pattern matching:

- `BinaryExpression`
- `BinaryComparison`
- `BinaryArithmetic`

This enables nested pattern matching on binary nodes. For example, the
following pattern match

```scala
case p: BinaryComparison if p.left.dataType == StringType &&
                            p.right.dataType == DateType =>
  p.makeCopy(Array(p.left, Cast(p.right, StringType)))
```

can be simplified to

```scala
case p @ BinaryComparison(l @ StringType(), r @ DateType()) =>
  p.makeCopy(Array(l, Cast(r, StringType)))
```

Author: Cheng Lian <l...@databricks.com>

Closes #6537 from liancheng/binary-node-patmat and squashes the following 
commits:

a3bf5fe [Cheng Lian] Fixes compilation error introduced while rebasing
b738986 [Cheng Lian] Renames `l`/`r` to `left`/`right` or `lhs`/`rhs`
14900ae [Cheng Lian] Simplifies binary node pattern matching


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bc0d76a2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bc0d76a2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bc0d76a2

Branch: refs/heads/master
Commit: bc0d76a246cc534234b96a661d70feb94b26538c
Parents: 700312e
Author: Cheng Lian <l...@databricks.com>
Authored: Fri Jun 5 23:06:19 2015 +0800
Committer: Cheng Lian <l...@databricks.com>
Committed: Fri Jun 5 23:06:19 2015 +0800

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    | 215 +++++++++----------
 .../sql/catalyst/expressions/Expression.scala   |   4 +
 .../sql/catalyst/expressions/arithmetic.scala   |   4 +
 .../sql/catalyst/expressions/predicates.scala   |   5 +-
 .../sql/catalyst/optimizer/Optimizer.scala      |  19 +-
 5 files changed, 119 insertions(+), 128 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bc0d76a2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index b064600..9b8a08a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -130,7 +130,7 @@ trait HiveTypeCoercion {
    * the appropriate numeric equivalent.
    */
   object ConvertNaNs extends Rule[LogicalPlan] {
-    private val stringNaN = Literal("NaN")
+    private val StringNaN = Literal("NaN")
 
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case q: LogicalPlan => q transformExpressions {
@@ -138,20 +138,20 @@ trait HiveTypeCoercion {
         case e if !e.childrenResolved => e
 
         /* Double Conversions */
-        case b: BinaryExpression if b.left == stringNaN && b.right.dataType == 
DoubleType =>
-          b.makeCopy(Array(b.right, Literal(Double.NaN)))
-        case b: BinaryExpression if b.left.dataType == DoubleType && b.right 
== stringNaN =>
-          b.makeCopy(Array(Literal(Double.NaN), b.left))
-        case b: BinaryExpression if b.left == stringNaN && b.right == 
stringNaN =>
-          b.makeCopy(Array(Literal(Double.NaN), b.left))
+        case b @ BinaryExpression(StringNaN, right @ DoubleType()) =>
+          b.makeCopy(Array(Literal(Double.NaN), right))
+        case b @ BinaryExpression(left @ DoubleType(), StringNaN) =>
+          b.makeCopy(Array(left, Literal(Double.NaN)))
 
         /* Float Conversions */
-        case b: BinaryExpression if b.left == stringNaN && b.right.dataType == 
FloatType =>
-          b.makeCopy(Array(b.right, Literal(Float.NaN)))
-        case b: BinaryExpression if b.left.dataType == FloatType && b.right == 
stringNaN =>
-          b.makeCopy(Array(Literal(Float.NaN), b.left))
-        case b: BinaryExpression if b.left == stringNaN && b.right == 
stringNaN =>
-          b.makeCopy(Array(Literal(Float.NaN), b.left))
+        case b @ BinaryExpression(StringNaN, right @ FloatType()) =>
+          b.makeCopy(Array(Literal(Float.NaN), right))
+        case b @ BinaryExpression(left @ FloatType(), StringNaN) =>
+          b.makeCopy(Array(left, Literal(Float.NaN)))
+
+        /* Use float NaN by default to avoid unnecessary type widening */
+        case b @ BinaryExpression(left @ StringNaN, StringNaN) =>
+          b.makeCopy(Array(left, Literal(Float.NaN)))
       }
     }
   }
@@ -184,21 +184,25 @@ trait HiveTypeCoercion {
       case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
         val castedInput = left.output.zip(right.output).map {
           // When a string is found on one side, make the other side a string 
too.
-          case (l, r) if l.dataType == StringType && r.dataType != StringType 
=>
-            (l, Alias(Cast(r, StringType), r.name)())
-          case (l, r) if l.dataType != StringType && r.dataType == StringType 
=>
-            (Alias(Cast(l, StringType), l.name)(), r)
-
-          case (l, r) if l.dataType != r.dataType =>
-            logDebug(s"Resolving mismatched union input ${l.dataType}, 
${r.dataType}")
-            findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { 
widestType =>
+          case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != 
StringType =>
+            (lhs, Alias(Cast(rhs, StringType), rhs.name)())
+          case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == 
StringType =>
+            (Alias(Cast(lhs, StringType), lhs.name)(), rhs)
+
+          case (lhs, rhs) if lhs.dataType != rhs.dataType =>
+            logDebug(s"Resolving mismatched union input ${lhs.dataType}, 
${rhs.dataType}")
+            findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { 
widestType =>
               val newLeft =
-                if (l.dataType == widestType) l else Alias(Cast(l, 
widestType), l.name)()
+                if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, 
widestType), lhs.name)()
               val newRight =
-                if (r.dataType == widestType) r else Alias(Cast(r, 
widestType), r.name)()
+                if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, 
widestType), rhs.name)()
 
               (newLeft, newRight)
-            }.getOrElse((l, r)) // If there is no applicable conversion, leave 
expression unchanged.
+            }.getOrElse {
+              // If there is no applicable conversion, leave expression 
unchanged.
+              (lhs, rhs)
+            }
+
           case other => other
         }
 
@@ -227,12 +231,10 @@ trait HiveTypeCoercion {
         // Skip nodes who's children have not been resolved yet.
         case e if !e.childrenResolved => e
 
-        case b: BinaryExpression if b.left.dataType != b.right.dataType =>
-          findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { 
widestType =>
-            val newLeft =
-              if (b.left.dataType == widestType) b.left else Cast(b.left, 
widestType)
-            val newRight =
-              if (b.right.dataType == widestType) b.right else Cast(b.right, 
widestType)
+        case b @ BinaryExpression(left, right) if left.dataType != 
right.dataType =>
+          findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { 
widestType =>
+            val newLeft = if (left.dataType == widestType) left else 
Cast(left, widestType)
+            val newRight = if (right.dataType == widestType) right else 
Cast(right, widestType)
             b.makeCopy(Array(newLeft, newRight))
           }.getOrElse(b)  // If there is no applicable conversion, leave 
expression unchanged.
       }
@@ -247,57 +249,42 @@ trait HiveTypeCoercion {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
-      case a: BinaryArithmetic if a.left.dataType == StringType =>
-        a.makeCopy(Array(Cast(a.left, DoubleType), a.right))
-      case a: BinaryArithmetic if a.right.dataType == StringType =>
-        a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
+      case a @ BinaryArithmetic(left @ StringType(), r) =>
+        a.makeCopy(Array(Cast(left, DoubleType), r))
+      case a @ BinaryArithmetic(left, right @ StringType()) =>
+        a.makeCopy(Array(left, Cast(right, DoubleType)))
 
       // we should cast all timestamp/date/string compare into string compare
-      case p: BinaryComparison if p.left.dataType == StringType &&
-                                  p.right.dataType == DateType =>
-        p.makeCopy(Array(p.left, Cast(p.right, StringType)))
-      case p: BinaryComparison if p.left.dataType == DateType &&
-                                  p.right.dataType == StringType =>
-        p.makeCopy(Array(Cast(p.left, StringType), p.right))
-      case p: BinaryComparison if p.left.dataType == StringType &&
-                                  p.right.dataType == TimestampType =>
-        p.makeCopy(Array(Cast(p.left, TimestampType), p.right))
-      case p: BinaryComparison if p.left.dataType == TimestampType &&
-                                  p.right.dataType == StringType =>
-        p.makeCopy(Array(p.left, Cast(p.right, TimestampType)))
-      case p: BinaryComparison if p.left.dataType == TimestampType &&
-                                  p.right.dataType == DateType =>
-        p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
-      case p: BinaryComparison if p.left.dataType == DateType &&
-                                  p.right.dataType == TimestampType =>
-        p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
-
-      case p: BinaryComparison if p.left.dataType == StringType &&
-                                  p.right.dataType != StringType =>
-        p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
-      case p: BinaryComparison if p.left.dataType != StringType &&
-                                  p.right.dataType == StringType =>
-        p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))
-
-      case i @ In(a, b) if a.dataType == DateType &&
-                           b.forall(_.dataType == StringType) =>
+      case p @ BinaryComparison(left @ StringType(), right @ DateType()) =>
+        p.makeCopy(Array(left, Cast(right, StringType)))
+      case p @ BinaryComparison(left @ DateType(), right @ StringType()) =>
+        p.makeCopy(Array(Cast(left, StringType), right))
+      case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) 
=>
+        p.makeCopy(Array(Cast(left, TimestampType), right))
+      case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) 
=>
+        p.makeCopy(Array(left, Cast(right, TimestampType)))
+      case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) =>
+        p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
+      case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) =>
+        p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
+
+      case p @ BinaryComparison(left @ StringType(), right) if right.dataType 
!= StringType =>
+        p.makeCopy(Array(Cast(left, DoubleType), right))
+      case p @ BinaryComparison(left, right @ StringType()) if left.dataType 
!= StringType =>
+        p.makeCopy(Array(left, Cast(right, DoubleType)))
+
+      case i @ In(a @ DateType(), b) if b.forall(_.dataType == StringType) =>
         i.makeCopy(Array(Cast(a, StringType), b))
-      case i @ In(a, b) if a.dataType == TimestampType &&
-                           b.forall(_.dataType == StringType) =>
+      case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == 
StringType) =>
         i.makeCopy(Array(a, b.map(Cast(_, TimestampType))))
-      case i @ In(a, b) if a.dataType == DateType &&
-                           b.forall(_.dataType == TimestampType) =>
+      case i @ In(a @ DateType(), b) if b.forall(_.dataType == TimestampType) 
=>
         i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
-      case i @ In(a, b) if a.dataType == TimestampType &&
-                           b.forall(_.dataType == DateType) =>
+      case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == DateType) 
=>
         i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
 
-      case Sum(e) if e.dataType == StringType =>
-        Sum(Cast(e, DoubleType))
-      case Average(e) if e.dataType == StringType =>
-        Average(Cast(e, DoubleType))
-      case Sqrt(e) if e.dataType == StringType =>
-        Sqrt(Cast(e, DoubleType))
+      case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
+      case Average(e @ StringType()) => Average(Cast(e, DoubleType))
+      case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType))
     }
   }
 
@@ -379,22 +366,22 @@ trait HiveTypeCoercion {
       // fix decimal precision for union
       case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
         val castedInput = left.output.zip(right.output).map {
-          case (l, r) if l.dataType != r.dataType =>
-            (l.dataType, r.dataType) match {
+          case (lhs, rhs) if lhs.dataType != rhs.dataType =>
+            (lhs.dataType, rhs.dataType) match {
               case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
                 // Union decimals with precision/scale p1/s2 and p2/s2  will 
be promoted to
                 // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
                 val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - 
s2), max(s1, s2))
-                (Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, 
fixedType), r.name)())
+                (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, 
fixedType), rhs.name)())
               case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) 
=>
-                (Alias(Cast(l, intTypeToFixed(t)), l.name)(), r)
+                (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs)
               case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) 
=>
-                (l, Alias(Cast(r, intTypeToFixed(t)), r.name)())
+                (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)())
               case (t, DecimalType.Fixed(p, s)) if 
floatTypeToFixed.contains(t) =>
-                (Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r)
+                (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs)
               case (DecimalType.Fixed(p, s), t) if 
floatTypeToFixed.contains(t) =>
-                (l, Alias(Cast(r, floatTypeToFixed(t)), r.name)())
-              case _ => (l, r)
+                (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)())
+              case _ => (lhs, rhs)
             }
           case other => other
         }
@@ -467,16 +454,16 @@ trait HiveTypeCoercion {
 
         // Promote integers inside a binary expression with fixed-precision 
decimals to decimals,
         // and fixed-precision decimals in an expression with floats / doubles 
to doubles
-        case b: BinaryExpression if b.left.dataType != b.right.dataType =>
-          (b.left.dataType, b.right.dataType) match {
+        case b @ BinaryExpression(left, right) if left.dataType != 
right.dataType =>
+          (left.dataType, right.dataType) match {
             case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
-              b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
+              b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right))
             case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
-              b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
+              b.makeCopy(Array(left, Cast(right, intTypeToFixed(t))))
             case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
-              b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
+              b.makeCopy(Array(left, Cast(right, DoubleType)))
             case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
-              b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
+              b.makeCopy(Array(Cast(left, DoubleType), right))
             case _ =>
               b
           }
@@ -525,31 +512,31 @@ trait HiveTypeCoercion {
       // all other cases are considered as false.
 
       // We may simplify the expression if one side is literal numeric values
-      case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
-        if trueValues.contains(value) => l
-      case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
-        if falseValues.contains(value) => Not(l)
-      case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
-        if trueValues.contains(value) => r
-      case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
-        if falseValues.contains(value) => Not(r)
-      case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
-        if trueValues.contains(value) => And(IsNotNull(l), l)
-      case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
-        if falseValues.contains(value) => And(IsNotNull(l), Not(l))
-      case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
-        if trueValues.contains(value) => And(IsNotNull(r), r)
-      case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
-        if falseValues.contains(value) => And(IsNotNull(r), Not(r))
-
-      case EqualTo(l @ BooleanType(), r @ NumericType()) =>
-        transform(l , r)
-      case EqualTo(l @ NumericType(), r @ BooleanType()) =>
-        transform(r, l)
-      case EqualNullSafe(l @ BooleanType(), r @ NumericType()) =>
-        transformNullSafe(l, r)
-      case EqualNullSafe(l @ NumericType(), r @ BooleanType()) =>
-        transformNullSafe(r, l)
+      case EqualTo(left @ BooleanType(), Literal(value, _: NumericType))
+        if trueValues.contains(value) => left
+      case EqualTo(left @ BooleanType(), Literal(value, _: NumericType))
+        if falseValues.contains(value) => Not(left)
+      case EqualTo(Literal(value, _: NumericType), right @ BooleanType())
+        if trueValues.contains(value) => right
+      case EqualTo(Literal(value, _: NumericType), right @ BooleanType())
+        if falseValues.contains(value) => Not(right)
+      case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType))
+        if trueValues.contains(value) => And(IsNotNull(left), left)
+      case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType))
+        if falseValues.contains(value) => And(IsNotNull(left), Not(left))
+      case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType())
+        if trueValues.contains(value) => And(IsNotNull(right), right)
+      case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType())
+        if falseValues.contains(value) => And(IsNotNull(right), Not(right))
+
+      case EqualTo(left @ BooleanType(), right @ NumericType()) =>
+        transform(left , right)
+      case EqualTo(left @ NumericType(), right @ BooleanType()) =>
+        transform(right, left)
+      case EqualNullSafe(left @ BooleanType(), right @ NumericType()) =>
+        transformNullSafe(left, right)
+      case EqualNullSafe(left @ NumericType(), right @ BooleanType()) =>
+        transformNullSafe(right, left)
     }
   }
 
@@ -630,7 +617,7 @@ trait HiveTypeCoercion {
       case d: Divide if d.dataType == DoubleType => d
       case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
 
-      case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
+      case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right, 
DoubleType))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bc0d76a2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 3cf851a..b2b9d1a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -118,6 +118,10 @@ abstract class BinaryExpression extends Expression with 
trees.BinaryNode[Express
   override def toString: String = s"($left $symbol $right)"
 }
 
+private[sql] object BinaryExpression {
+  def unapply(e: BinaryExpression): Option[(Expression, Expression)] = 
Some((e.left, e.right))
+}
+
 abstract class LeafExpression extends Expression with 
trees.LeafNode[Expression] {
   self: Product =>
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bc0d76a2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 2ac53f8..a3770f9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -118,6 +118,10 @@ abstract class BinaryArithmetic extends BinaryExpression {
     sys.error(s"BinaryArithmetics must override either eval or evalInternal")
 }
 
+private[sql] object BinaryArithmetic {
+  def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = 
Some((e.left, e.right))
+}
+
 case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
   override def symbol: String = "+"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bc0d76a2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 807021d..58273b1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -202,9 +202,8 @@ abstract class BinaryComparison extends BinaryExpression 
with Predicate {
     sys.error(s"BinaryComparisons must override either eval or evalInternal")
 }
 
-object BinaryComparison {
-  def unapply(b: BinaryComparison): Option[(Expression, Expression)] =
-    Some((b.left, b.right))
+private[sql] object BinaryComparison {
+  def unapply(e: BinaryComparison): Option[(Expression, Expression)] = 
Some((e.left, e.right))
 }
 
 case class EqualTo(left: Expression, right: Expression) extends 
BinaryComparison {

http://git-wip-us.apache.org/repos/asf/spark/blob/bc0d76a2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0a17b10..c16f08d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -266,7 +266,7 @@ object NullPropagation extends Rule[LogicalPlan] {
         if (newChildren.length == 0) {
           Literal.create(null, e.dataType)
         } else if (newChildren.length == 1) {
-          newChildren(0)
+          newChildren.head
         } else {
           Coalesce(newChildren)
         }
@@ -280,21 +280,18 @@ object NullPropagation extends Rule[LogicalPlan] {
       case e: MinOf => e
 
       // Put exceptional cases above if any
-      case e: BinaryArithmetic => e.children match {
-        case Literal(null, _) :: right :: Nil => Literal.create(null, 
e.dataType)
-        case left :: Literal(null, _) :: Nil => Literal.create(null, 
e.dataType)
-        case _ => e
-      }
-      case e: BinaryComparison => e.children match {
-        case Literal(null, _) :: right :: Nil => Literal.create(null, 
e.dataType)
-        case left :: Literal(null, _) :: Nil => Literal.create(null, 
e.dataType)
-        case _ => e
-      }
+      case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, 
e.dataType)
+      case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, 
e.dataType)
+
+      case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, 
e.dataType)
+      case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, 
e.dataType)
+
       case e: StringRegexExpression => e.children match {
         case Literal(null, _) :: right :: Nil => Literal.create(null, 
e.dataType)
         case left :: Literal(null, _) :: Nil => Literal.create(null, 
e.dataType)
         case _ => e
       }
+
       case e: StringComparison => e.children match {
         case Literal(null, _) :: right :: Nil => Literal.create(null, 
e.dataType)
         case left :: Literal(null, _) :: Nil => Literal.create(null, 
e.dataType)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to