Repository: spark
Updated Branches:
  refs/heads/master 805541117 -> 678c4da0f


[SPARK-7266] Add ExpectsInputTypes to expressions when possible.

This should give us better analysis-time error messages (rather than runtime ones)
and automatic type casting.

Author: Reynold Xin <r...@databricks.com>

Closes #5796 from rxin/expected-input-types and squashes the following commits:

c900760 [Reynold Xin] [SPARK-7266] Add ExpectsInputTypes to expressions when 
possible.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/678c4da0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/678c4da0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/678c4da0

Branch: refs/heads/master
Commit: 678c4da0fa1bbfb6b5a0d3aced7aefa1bbbc193c
Parents: 8055411
Author: Reynold Xin <r...@databricks.com>
Authored: Mon May 4 18:03:07 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon May 4 18:03:07 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    | 58 ++++++++++----------
 .../sql/catalyst/expressions/Expression.scala   |  3 +-
 .../sql/catalyst/expressions/arithmetic.scala   |  4 +-
 .../sql/catalyst/expressions/predicates.scala   | 22 +++++---
 .../catalyst/expressions/stringOperations.scala | 40 ++++++++------
 5 files changed, 71 insertions(+), 56 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/678c4da0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 73c9a1c..831fb4f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -239,37 +239,43 @@ trait HiveTypeCoercion {
         a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
 
       // we should cast all timestamp/date/string compare into string compare
-      case p: BinaryPredicate if p.left.dataType == StringType
-        && p.right.dataType == DateType =>
+      case p: BinaryComparison if p.left.dataType == StringType &&
+                                  p.right.dataType == DateType =>
         p.makeCopy(Array(p.left, Cast(p.right, StringType)))
-      case p: BinaryPredicate if p.left.dataType == DateType
-        && p.right.dataType == StringType =>
+      case p: BinaryComparison if p.left.dataType == DateType &&
+                                  p.right.dataType == StringType =>
         p.makeCopy(Array(Cast(p.left, StringType), p.right))
-      case p: BinaryPredicate if p.left.dataType == StringType
-        && p.right.dataType == TimestampType =>
+      case p: BinaryComparison if p.left.dataType == StringType &&
+                                  p.right.dataType == TimestampType =>
         p.makeCopy(Array(p.left, Cast(p.right, StringType)))
-      case p: BinaryPredicate if p.left.dataType == TimestampType
-        && p.right.dataType == StringType =>
+      case p: BinaryComparison if p.left.dataType == TimestampType &&
+                                  p.right.dataType == StringType =>
         p.makeCopy(Array(Cast(p.left, StringType), p.right))
-      case p: BinaryPredicate if p.left.dataType == TimestampType
-        && p.right.dataType == DateType =>
+      case p: BinaryComparison if p.left.dataType == TimestampType &&
+                                  p.right.dataType == DateType =>
         p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
-      case p: BinaryPredicate if p.left.dataType == DateType
-        && p.right.dataType == TimestampType =>
+      case p: BinaryComparison if p.left.dataType == DateType &&
+                                  p.right.dataType == TimestampType =>
         p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
 
-      case p: BinaryPredicate if p.left.dataType == StringType && 
p.right.dataType != StringType =>
+      case p: BinaryComparison if p.left.dataType == StringType &&
+                                  p.right.dataType != StringType =>
         p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
-      case p: BinaryPredicate if p.left.dataType != StringType && 
p.right.dataType == StringType =>
+      case p: BinaryComparison if p.left.dataType != StringType &&
+                                  p.right.dataType == StringType =>
         p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))
 
-      case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == 
StringType) =>
+      case i @ In(a, b) if a.dataType == DateType &&
+                           b.forall(_.dataType == StringType) =>
         i.makeCopy(Array(Cast(a, StringType), b))
-      case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType 
== StringType) =>
+      case i @ In(a, b) if a.dataType == TimestampType &&
+                           b.forall(_.dataType == StringType) =>
         i.makeCopy(Array(Cast(a, StringType), b))
-      case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == 
TimestampType) =>
+      case i @ In(a, b) if a.dataType == DateType &&
+                           b.forall(_.dataType == TimestampType) =>
         i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
-      case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType 
== DateType) =>
+      case i @ In(a, b) if a.dataType == TimestampType &&
+                           b.forall(_.dataType == DateType) =>
         i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
 
       case Sum(e) if e.dataType == StringType =>
@@ -420,19 +426,19 @@ trait HiveTypeCoercion {
           )
 
         case LessThan(e1 @ DecimalType.Expression(p1, s1),
-        e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+                      e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 
!= s2 =>
           LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, 
DecimalType.Unlimited))
 
         case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
-        e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+                             e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 
|| s1 != s2 =>
           LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, 
DecimalType.Unlimited))
 
         case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
-        e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+                         e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || 
s1 != s2 =>
           GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, 
DecimalType.Unlimited))
 
         case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
-        e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+                                e2 @ DecimalType.Expression(p2, s2)) if p1 != 
p2 || s1 != s2 =>
           GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, 
DecimalType.Unlimited))
 
         // Promote integers inside a binary expression with fixed-precision 
decimals to decimals,
@@ -481,8 +487,8 @@ trait HiveTypeCoercion {
       // No need to change the EqualNullSafe operators, too
       case e: EqualNullSafe => e
       // Otherwise turn them to Byte types so that there exists and ordering.
-      case p: BinaryComparison
-          if p.left.dataType == BooleanType && p.right.dataType == BooleanType 
=>
+      case p: BinaryComparison if p.left.dataType == BooleanType &&
+                                  p.right.dataType == BooleanType =>
         p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
     }
   }
@@ -564,10 +570,6 @@ trait HiveTypeCoercion {
       case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
       case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
 
-      // Compatible with Hive
-      case Substring(e, start, len) if e.dataType != StringType =>
-        Substring(Cast(e, StringType), start, len)
-
       // Coalesce should return the first non-null value, which could be any 
column
       // from the list. So we need to make sure the return type is 
deterministic and
       // compatible with every child column.

http://git-wip-us.apache.org/repos/asf/spark/blob/678c4da0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 1d71c1b..4fd1bc4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types._
@@ -86,6 +85,8 @@ abstract class BinaryExpression extends Expression with 
trees.BinaryNode[Express
 
   override def foldable: Boolean = left.foldable && right.foldable
 
+  override def nullable: Boolean = left.nullable || right.nullable
+
   override def toString: String = s"($left $symbol $right)"
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/678c4da0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 140ccd8..c7a37ad 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -74,14 +74,12 @@ abstract class BinaryArithmetic extends BinaryExpression {
 
   type EvaluatedType = Any
 
-  def nullable: Boolean = left.nullable || right.nullable
-
   override lazy val resolved =
     left.resolved && right.resolved &&
     left.dataType == right.dataType &&
     !DecimalType.isFixed(left.dataType)
 
-  def dataType: DataType = {
+  override def dataType: DataType = {
     if (!resolved) {
       throw new UnresolvedException(this,
         s"datatype. Can not resolve due to differing types ${left.dataType}, 
${right.dataType}")

http://git-wip-us.apache.org/repos/asf/spark/blob/678c4da0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 9cb00cb..26c38c5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -70,16 +70,14 @@ trait PredicateHelper {
     expr.references.subsetOf(plan.outputSet)
 }
 
-abstract class BinaryPredicate extends BinaryExpression with Predicate {
-  self: Product =>
-  override def nullable: Boolean = left.nullable || right.nullable
-}
 
-case class Not(child: Expression) extends UnaryExpression with Predicate {
+case class Not(child: Expression) extends UnaryExpression with Predicate with 
ExpectsInputTypes {
   override def foldable: Boolean = child.foldable
   override def nullable: Boolean = child.nullable
   override def toString: String = s"NOT $child"
 
+  override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)
+
   override def eval(input: Row): Any = {
     child.eval(input) match {
       case null => null
@@ -120,7 +118,11 @@ case class InSet(value: Expression, hset: Set[Any])
   }
 }
 
-case class And(left: Expression, right: Expression) extends BinaryPredicate {
+case class And(left: Expression, right: Expression)
+  extends BinaryExpression with Predicate with ExpectsInputTypes {
+
+  override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, 
BooleanType)
+
   override def symbol: String = "&&"
 
   override def eval(input: Row): Any = {
@@ -142,7 +144,11 @@ case class And(left: Expression, right: Expression) 
extends BinaryPredicate {
   }
 }
 
-case class Or(left: Expression, right: Expression) extends BinaryPredicate {
+case class Or(left: Expression, right: Expression)
+  extends BinaryExpression with Predicate with ExpectsInputTypes {
+
+  override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, 
BooleanType)
+
   override def symbol: String = "||"
 
   override def eval(input: Row): Any = {
@@ -164,7 +170,7 @@ case class Or(left: Expression, right: Expression) extends 
BinaryPredicate {
   }
 }
 
-abstract class BinaryComparison extends BinaryPredicate {
+abstract class BinaryComparison extends BinaryExpression with Predicate {
   self: Product =>
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/678c4da0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index d597bf7..d6f23df 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -22,7 +22,7 @@ import java.util.regex.Pattern
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.types._
 
-trait StringRegexExpression {
+trait StringRegexExpression extends ExpectsInputTypes {
   self: BinaryExpression =>
 
   type EvaluatedType = Any
@@ -32,6 +32,7 @@ trait StringRegexExpression {
 
   override def nullable: Boolean = left.nullable || right.nullable
   override def dataType: DataType = BooleanType
+  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
 
   // try cache the pattern for Literal
   private lazy val cache: Pattern = right match {
@@ -57,11 +58,11 @@ trait StringRegexExpression {
       if(r == null) {
         null
       } else {
-        val regex = pattern(r.asInstanceOf[UTF8String].toString)
+        val regex = pattern(r.asInstanceOf[UTF8String].toString())
         if(regex == null) {
           null
         } else {
-          matches(regex, l.asInstanceOf[UTF8String].toString)
+          matches(regex, l.asInstanceOf[UTF8String].toString())
         }
       }
     }
@@ -110,7 +111,7 @@ case class RLike(left: Expression, right: Expression)
   override def matches(regex: Pattern, str: String): Boolean = 
regex.matcher(str).find(0)
 }
 
-trait CaseConversionExpression {
+trait CaseConversionExpression extends ExpectsInputTypes {
   self: UnaryExpression =>
 
   type EvaluatedType = Any
@@ -118,8 +119,9 @@ trait CaseConversionExpression {
   def convert(v: UTF8String): UTF8String
 
   override def foldable: Boolean = child.foldable
-  def nullable: Boolean = child.nullable
-  def dataType: DataType = StringType
+  override def nullable: Boolean = child.nullable
+  override def dataType: DataType = StringType
+  override def expectedChildTypes: Seq[DataType] = Seq(StringType)
 
   override def eval(input: Row): Any = {
     val evaluated = child.eval(input)
@@ -136,7 +138,7 @@ trait CaseConversionExpression {
  */
 case class Upper(child: Expression) extends UnaryExpression with 
CaseConversionExpression {
   
-  override def convert(v: UTF8String): UTF8String = v.toUpperCase
+  override def convert(v: UTF8String): UTF8String = v.toUpperCase()
 
   override def toString: String = s"Upper($child)"
 }
@@ -146,21 +148,21 @@ case class Upper(child: Expression) extends 
UnaryExpression with CaseConversionE
  */
 case class Lower(child: Expression) extends UnaryExpression with 
CaseConversionExpression {
   
-  override def convert(v: UTF8String): UTF8String = v.toLowerCase
+  override def convert(v: UTF8String): UTF8String = v.toLowerCase()
 
   override def toString: String = s"Lower($child)"
 }
 
 /** A base trait for functions that compare two strings, returning a boolean. 
*/
 trait StringComparison {
-  self: BinaryPredicate =>
+  self: BinaryExpression =>
+
+  def compare(l: UTF8String, r: UTF8String): Boolean
 
   override type EvaluatedType = Any
 
   override def nullable: Boolean = left.nullable || right.nullable
 
-  def compare(l: UTF8String, r: UTF8String): Boolean
-
   override def eval(input: Row): Any = {
     val leftEval = left.eval(input)
     if(leftEval == null) {
@@ -181,31 +183,35 @@ trait StringComparison {
  * A function that returns true if the string `left` contains the string 
`right`.
  */
 case class Contains(left: Expression, right: Expression)
-    extends BinaryPredicate with StringComparison {
+    extends BinaryExpression with Predicate with StringComparison with 
ExpectsInputTypes {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
+  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
 }
 
 /**
  * A function that returns true if the string `left` starts with the string 
`right`.
  */
 case class StartsWith(left: Expression, right: Expression)
-    extends BinaryPredicate with StringComparison {
+    extends BinaryExpression with Predicate with StringComparison with 
ExpectsInputTypes {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
+  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
 }
 
 /**
  * A function that returns true if the string `left` ends with the string 
`right`.
  */
 case class EndsWith(left: Expression, right: Expression)
-    extends BinaryPredicate with StringComparison {
+    extends BinaryExpression with Predicate with StringComparison with 
ExpectsInputTypes {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
+  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
 }
 
 /**
  * A function that takes a substring of its first argument starting at a given 
position.
  * Defined for String and Binary types.
  */
-case class Substring(str: Expression, pos: Expression, len: Expression) 
extends Expression {
+case class Substring(str: Expression, pos: Expression, len: Expression)
+  extends Expression with ExpectsInputTypes {
   
   type EvaluatedType = Any
 
@@ -219,6 +225,8 @@ case class Substring(str: Expression, pos: Expression, len: 
Expression) extends
     if (str.dataType == BinaryType) str.dataType else StringType
   }
 
+  override def expectedChildTypes: Seq[DataType] = Seq(StringType, 
IntegerType, IntegerType)
+
   override def children: Seq[Expression] = str :: pos :: len :: Nil
 
   @inline
@@ -258,7 +266,7 @@ case class Substring(str: Expression, pos: Expression, len: 
Expression) extends
           val (st, end) = slicePos(start, length, () => ba.length)
           ba.slice(st, end)
         case s: UTF8String =>
-          val (st, end) = slicePos(start, length, () => s.length)
+          val (st, end) = slicePos(start, length, () => s.length())
           s.slice(st, end)
       }
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to