Repository: spark
Updated Branches:
  refs/heads/branch-2.2 3686c2e96 -> f59f9a380


[SPARK-20876][SQL][BACKPORT-2.2] If the input parameter is float type for ceil
or floor, the result is not what we expected

## What changes were proposed in this pull request?

This PR is to backport #18103 to Spark 2.2

## How was this patch tested?
Unit tests (new Ceil and Floor cases added to MathExpressionsSuite).

Author: liuxian <liu.xi...@zte.com.cn>

Closes #18155 from 10110346/wip-lx-0531.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f59f9a38
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f59f9a38
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f59f9a38

Branch: refs/heads/branch-2.2
Commit: f59f9a380351726de20453ab101f46e199a7079c
Parents: 3686c2e
Author: liuxian <liu.xi...@zte.com.cn>
Authored: Wed May 31 11:43:36 2017 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Wed May 31 11:43:36 2017 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/mathExpressions.scala  | 16 ++++++++++------
 .../expressions/MathExpressionsSuite.scala      | 20 ++++++++++++++++++++
 2 files changed, 30 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f59f9a38/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index e040ad0..52f3af4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -232,18 +232,20 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
   }
 
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(DoubleType, DecimalType))
+    Seq(TypeCollection(DoubleType, DecimalType, LongType))
 
   protected override def nullSafeEval(input: Any): Any = child.dataType match {
+    case LongType => input.asInstanceOf[Long]
     case DoubleType => f(input.asInstanceOf[Double]).toLong
-    case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil
+    case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].ceil
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     child.dataType match {
       case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
-      case DecimalType.Fixed(precision, scale) =>
+      case DecimalType.Fixed(_, _) =>
         defineCodeGen(ctx, ev, c => s"$c.ceil()")
+      case LongType => defineCodeGen(ctx, ev, c => s"$c")
       case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
     }
   }
@@ -347,18 +349,20 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
   }
 
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(DoubleType, DecimalType))
+    Seq(TypeCollection(DoubleType, DecimalType, LongType))
 
   protected override def nullSafeEval(input: Any): Any = child.dataType match {
+    case LongType => input.asInstanceOf[Long]
     case DoubleType => f(input.asInstanceOf[Double]).toLong
-    case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor
+    case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].floor
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     child.dataType match {
       case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
-      case DecimalType.Fixed(precision, scale) =>
+      case DecimalType.Fixed(_, _) =>
         defineCodeGen(ctx, ev, c => s"$c.floor()")
+      case LongType => defineCodeGen(ctx, ev, c => s"$c")
       case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/f59f9a38/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index 1555dd1..69ada82 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -252,6 +252,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3))
     checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0))
     checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0))
+
+    val doublePi: Double = 3.1415
+    val floatPi: Float = 3.1415f
+    val longLit: Long = 12345678901234567L
+    checkEvaluation(Ceil(doublePi), 4L, EmptyRow)
+    checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow)
+    checkEvaluation(Ceil(longLit), longLit, EmptyRow)
+    checkEvaluation(Ceil(-doublePi), -3L, EmptyRow)
+    checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow)
+    checkEvaluation(Ceil(-longLit), -longLit, EmptyRow)
   }
 
   test("floor") {
@@ -262,6 +272,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3))
     checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0))
     checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0))
+
+    val doublePi: Double = 3.1415
+    val floatPi: Float = 3.1415f
+    val longLit: Long = 12345678901234567L
+    checkEvaluation(Floor(doublePi), 3L, EmptyRow)
+    checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow)
+    checkEvaluation(Floor(longLit), longLit, EmptyRow)
+    checkEvaluation(Floor(-doublePi), -4L, EmptyRow)
+    checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow)
+    checkEvaluation(Floor(-longLit), -longLit, EmptyRow)
   }
 
   test("factorial") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to