Repository: spark Updated Branches: refs/heads/master 53aa8316e -> 2f1519def
SPARK-2813: [SQL] Implement SQRT() directly in Spark SQL This PR adds a native implementation for SQL SQRT() and thus avoids delegating this function to Hive. Author: William Benton <wi...@redhat.com> Closes #1750 from willb/spark-2813 and squashes the following commits: 22c8a79 [William Benton] Fixed missed newline from rebase d673861 [William Benton] Added string coercions for SQRT and associated test case e125df4 [William Benton] Added ExpressionEvaluationSuite test cases for SQRT 7b84bcd [William Benton] SQL SQRT now properly returns NULL for NULL inputs 8256971 [William Benton] added SQRT test to SqlQuerySuite 504d2e5 [William Benton] Added native SQRT implementation Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2f1519de Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2f1519de Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2f1519de Branch: refs/heads/master Commit: 2f1519defaba4f3c7d536669f909bfd9e13e4069 Parents: 53aa831 Author: William Benton <wi...@redhat.com> Authored: Fri Aug 29 15:26:59 2014 -0700 Committer: Michael Armbrust <mich...@databricks.com> Committed: Fri Aug 29 15:26:59 2014 -0700 ---------------------------------------------------------------------- .../org/apache/spark/sql/catalyst/SqlParser.scala | 2 ++ .../sql/catalyst/analysis/HiveTypeCoercion.scala | 2 ++ .../spark/sql/catalyst/expressions/arithmetic.scala | 13 +++++++++++++ .../expressions/ExpressionEvaluationSuite.scala | 13 +++++++++++++ .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 14 ++++++++++++++ .../main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 ++ 6 files changed, 46 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2f1519de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 2c73a80..4f166c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -122,6 +122,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val EXCEPT = Keyword("EXCEPT") protected val SUBSTR = Keyword("SUBSTR") protected val SUBSTRING = Keyword("SUBSTRING") + protected val SQRT = Keyword("SQRT") // Use reflection to find the reserved words defined in this class. protected val reservedWords = @@ -323,6 +324,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l) } | + SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ { case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs) } http://git-wip-us.apache.org/repos/asf/spark/blob/2f1519de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 15eb598..ecfcd62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -227,6 +227,8 @@ trait HiveTypeCoercion { Sum(Cast(e, DoubleType)) case Average(e) if e.dataType == StringType => Average(Cast(e, DoubleType)) + case Sqrt(e) if e.dataType == StringType => + Sqrt(Cast(e, DoubleType)) } } http://git-wip-us.apache.org/repos/asf/spark/blob/2f1519de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index aae86a3..56f0428 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -33,6 +33,19 @@ case class UnaryMinus(child: Expression) extends UnaryExpression { } } +case class Sqrt(child: Expression) extends UnaryExpression { + type EvaluatedType = Any + + def dataType = child.dataType + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"SQRT($child)" + + override def eval(input: Row): Any = { + n1(child, input, ((na,a) => math.sqrt(na.toDouble(a)))) + } +} + abstract class BinaryArithmetic extends BinaryExpression { self: Product => http://git-wip-us.apache.org/repos/asf/spark/blob/2f1519de/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index f1df817..b961346 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -577,4 +577,17 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(s.substring(0, 2), "ex", row) checkEvaluation(s.substring(0), "example", row) } + + test("SQRT") { + val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) + val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) + val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) + val d = 'a.double.at(0) + + for ((row, expected) <- rowSequence zip expectedResults) { + checkEvaluation(Sqrt(d), expected, row) + } + + checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null))) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/2f1519de/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 9b2a36d..4047bc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -34,6 +34,20 @@ class SQLQuerySuite extends QueryTest { "test") } + test("SQRT") { + checkAnswer( + sql("SELECT SQRT(key) FROM testData"), + (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq + ) + } + + test("SQRT with automatic string casts") { + checkAnswer( + sql("SELECT SQRT(CAST(key AS STRING)) FROM testData"), + (1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq + ) + } + test("SPARK-2407 Added Parser of SQL SUBSTR()") { checkAnswer( sql("SELECT substr(tableName, 1, 2) FROM tableName"), http://git-wip-us.apache.org/repos/asf/spark/blob/2f1519de/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index fa3adfd..a4dd6be 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -889,6 +889,7 @@ private[hive] object HiveQl { val WHEN = "(?i)WHEN".r val CASE = "(?i)CASE".r val SUBSTR = "(?i)SUBSTR(?:ING)?".r + val SQRT = "(?i)SQRT".r protected def nodeToExpr(node: Node): Expression = node match { /* Attribute References */ @@ -958,6 +959,7 @@ private[hive] object HiveQl { case Token(DIV(), left :: right:: Nil) => Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) + case Token("TOK_FUNCTION", Token(SQRT(), Nil) :: arg :: Nil) => Sqrt(nodeToExpr(arg)) /* Comparisons */ case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org