Repository: spark
Updated Branches:
  refs/heads/master 53aa8316e -> 2f1519def


SPARK-2813: [SQL] Implement SQRT() directly in Spark SQL

This PR adds a native implementation for SQL SQRT() and thus avoids delegating 
this function to Hive.

Author: William Benton <wi...@redhat.com>

Closes #1750 from willb/spark-2813 and squashes the following commits:

22c8a79 [William Benton] Fixed missed newline from rebase
d673861 [William Benton] Added string coercions for SQRT and associated test 
case
e125df4 [William Benton] Added ExpressionEvaluationSuite test cases for SQRT
7b84bcd [William Benton] SQL SQRT now properly returns NULL for NULL inputs
8256971 [William Benton] added SQRT test to SqlQuerySuite
504d2e5 [William Benton] Added native SQRT implementation


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2f1519de
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2f1519de
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2f1519de

Branch: refs/heads/master
Commit: 2f1519defaba4f3c7d536669f909bfd9e13e4069
Parents: 53aa831
Author: William Benton <wi...@redhat.com>
Authored: Fri Aug 29 15:26:59 2014 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Fri Aug 29 15:26:59 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/catalyst/SqlParser.scala     |  2 ++
 .../sql/catalyst/analysis/HiveTypeCoercion.scala      |  2 ++
 .../spark/sql/catalyst/expressions/arithmetic.scala   | 13 +++++++++++++
 .../expressions/ExpressionEvaluationSuite.scala       | 13 +++++++++++++
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala    | 14 ++++++++++++++
 .../main/scala/org/apache/spark/sql/hive/HiveQl.scala |  2 ++
 6 files changed, 46 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2f1519de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 2c73a80..4f166c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -122,6 +122,7 @@ class SqlParser extends StandardTokenParsers with 
PackratParsers {
   protected val EXCEPT = Keyword("EXCEPT")
   protected val SUBSTR = Keyword("SUBSTR")
   protected val SUBSTRING = Keyword("SUBSTRING")
+  protected val SQRT = Keyword("SQRT")
 
   // Use reflection to find the reserved words defined in this class.
   protected val reservedWords =
@@ -323,6 +324,7 @@ class SqlParser extends StandardTokenParsers with 
PackratParsers {
     (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ 
expression <~ ")" ^^ {
       case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l)
     } |
+    SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } |
     ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ {
       case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/2f1519de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 15eb598..ecfcd62 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -227,6 +227,8 @@ trait HiveTypeCoercion {
         Sum(Cast(e, DoubleType))
       case Average(e) if e.dataType == StringType =>
         Average(Cast(e, DoubleType))
+      case Sqrt(e) if e.dataType == StringType =>
+        Sqrt(Cast(e, DoubleType))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2f1519de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index aae86a3..56f0428 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -33,6 +33,19 @@ case class UnaryMinus(child: Expression) extends 
UnaryExpression {
   }
 }
 
+case class Sqrt(child: Expression) extends UnaryExpression {
+  type EvaluatedType = Any
+  
+  def dataType = child.dataType
+  override def foldable = child.foldable
+  def nullable = child.nullable
+  override def toString = s"SQRT($child)"
+
+  override def eval(input: Row): Any = {
+    n1(child, input, ((na,a) => math.sqrt(na.toDouble(a))))
+  }
+}
+
 abstract class BinaryArithmetic extends BinaryExpression {
   self: Product =>
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2f1519de/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index f1df817..b961346 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -577,4 +577,17 @@ class ExpressionEvaluationSuite extends FunSuite {
     checkEvaluation(s.substring(0, 2), "ex", row)
     checkEvaluation(s.substring(0), "example", row)
   }
+
+  test("SQRT") {
+    val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24))
+    val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
+    val rowSequence = inputSequence.map(l => new 
GenericRow(Array[Any](l.toDouble)))
+    val d = 'a.double.at(0)
+    
+    for ((row, expected) <- rowSequence zip expectedResults) {
+      checkEvaluation(Sqrt(d), expected, row)
+    }
+
+    checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new 
GenericRow(Array[Any](null)))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2f1519de/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 9b2a36d..4047bc0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -34,6 +34,20 @@ class SQLQuerySuite extends QueryTest {
       "test")
   }
 
+  test("SQRT") {
+    checkAnswer(
+      sql("SELECT SQRT(key) FROM testData"),
+      (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq
+    )
+  }
+  
+  test("SQRT with automatic string casts") {
+    checkAnswer(
+      sql("SELECT SQRT(CAST(key AS STRING)) FROM testData"),
+      (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq
+    )
+  }
+  
   test("SPARK-2407 Added Parser of SQL SUBSTR()") {
     checkAnswer(
       sql("SELECT substr(tableName, 1, 2) FROM tableName"),

http://git-wip-us.apache.org/repos/asf/spark/blob/2f1519de/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index fa3adfd..a4dd6be 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -889,6 +889,7 @@ private[hive] object HiveQl {
   val WHEN = "(?i)WHEN".r
   val CASE = "(?i)CASE".r
   val SUBSTR = "(?i)SUBSTR(?:ING)?".r
+  val SQRT = "(?i)SQRT".r
 
   protected def nodeToExpr(node: Node): Expression = node match {
     /* Attribute References */
@@ -958,6 +959,7 @@ private[hive] object HiveQl {
     case Token(DIV(), left :: right:: Nil) =>
       Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType)
     case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), 
nodeToExpr(right))
+    case Token("TOK_FUNCTION", Token(SQRT(), Nil) :: arg :: Nil) => 
Sqrt(nodeToExpr(arg))
 
     /* Comparisons */
     case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), 
nodeToExpr(right))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to