Repository: spark
Updated Branches:
  refs/heads/master fe2fb7fb7 -> c4fd2a242


[SPARK-9759] [SQL] improve decimal.times() and cast(int, decimalType)

This patch optimizes two things:

1. Passing a MathContext to java.math.BigDecimal's multiply/divide/remainder so 
the rounding happens inside the operation, because applying the context up 
front via scala.math.BigDecimal.apply(MathContext) is expensive.

2. Casting integer/short/byte to decimal directly (without going through double).

Together, these two optimizations speed up the end-to-end time of an aggregation 
(SUM(short * decimal(5, 2))) by about 75% (from 19s to 10.8s).
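
For reference, the first change amounts to rounding inside the java.math.BigDecimal 
operation instead of wrapping each operand first. A minimal standalone Scala 
sketch of the two approaches (the 38-digit HALF_UP context is an assumption 
standing in for Spark's MATH_CONTEXT):

    import java.math.{MathContext, RoundingMode}

    object MathContextSketch {
      // Assumed stand-in for Decimal.MATH_CONTEXT in Spark.
      val mc = new MathContext(38, RoundingMode.HALF_UP)

      def main(args: Array[String]): Unit = {
        val a = BigDecimal("123.456")
        val b = BigDecimal("78.90")

        // Old path: apply(MathContext) allocates a new scala.math.BigDecimal
        // per operand before multiplying.
        val viaApply = a(mc) * b(mc)

        // New path: multiply the underlying java.math.BigDecimal values once,
        // letting the MathContext round only the result.
        val direct = BigDecimal(a.bigDecimal.multiply(b.bigDecimal, mc))

        println(viaApply) // 9740.67840
        println(direct)   // 9740.67840
      }
    }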

Author: Davies Liu <dav...@databricks.com>

Closes #8052 from davies/optimize_decimal and squashes the following commits:

225efad [Davies Liu] improve decimal.times() and cast(int, decimalType)


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c4fd2a24
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c4fd2a24
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c4fd2a24

Branch: refs/heads/master
Commit: c4fd2a242228ee101904770446e3f37d49e39b76
Parents: fe2fb7f
Author: Davies Liu <dav...@databricks.com>
Authored: Mon Aug 10 13:55:11 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon Aug 10 13:55:11 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   | 42 +++++++-------------
 .../org/apache/spark/sql/types/Decimal.scala    | 12 +++---
 2 files changed, 22 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c4fd2a24/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 946c5a9..616b9e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -155,7 +155,7 @@ case class Cast(child: Expression, dataType: DataType)
     case ByteType =>
       buildCast[Byte](_, _ != 0)
     case DecimalType() =>
-      buildCast[Decimal](_, _ != Decimal.ZERO)
+      buildCast[Decimal](_, !_.isZero)
     case DoubleType =>
       buildCast[Double](_, _ != 0)
     case FloatType =>
@@ -315,13 +315,13 @@ case class Cast(child: Expression, dataType: DataType)
     case TimestampType =>
       // Note that we lose precision here.
       buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
-    case DecimalType() =>
+    case dt: DecimalType =>
       b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
-    case LongType =>
-      b => changePrecision(Decimal(b.asInstanceOf[Long]), target)
-    case x: NumericType => // All other numeric types can be represented precisely as Doubles
+    case t: IntegralType =>
+      b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target)
+    case x: FractionalType =>
       b => try {
-        changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target)
+        changePrecision(Decimal(x.fractional.asInstanceOf[Fractional[Any]].toDouble(b)), target)
       } catch {
         case _: NumberFormatException => null
       }
@@ -534,10 +534,7 @@ case class Cast(child: Expression, dataType: DataType)
         (c, evPrim, evNull) =>
           s"""
             try {
-              org.apache.spark.sql.types.Decimal tmpDecimal =
-                new org.apache.spark.sql.types.Decimal().set(
-                  new scala.math.BigDecimal(
-                    new java.math.BigDecimal($c.toString())));
+              Decimal tmpDecimal = Decimal.apply(new java.math.BigDecimal($c.toString()));
               ${changePrecision("tmpDecimal", target, evPrim, evNull)}
             } catch (java.lang.NumberFormatException e) {
               $evNull = true;
@@ -546,12 +543,7 @@ case class Cast(child: Expression, dataType: DataType)
       case BooleanType =>
         (c, evPrim, evNull) =>
           s"""
-            org.apache.spark.sql.types.Decimal tmpDecimal = null;
-            if ($c) {
-              tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1);
-            } else {
-              tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0);
-            }
+            Decimal tmpDecimal = $c ? Decimal.apply(1) : Decimal.apply(0);
             ${changePrecision("tmpDecimal", target, evPrim, evNull)}
           """
       case DateType =>
@@ -561,32 +553,28 @@ case class Cast(child: Expression, dataType: DataType)
         // Note that we lose precision here.
         (c, evPrim, evNull) =>
           s"""
-            org.apache.spark.sql.types.Decimal tmpDecimal =
-              new org.apache.spark.sql.types.Decimal().set(
-                scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
+            Decimal tmpDecimal = Decimal.apply(
+              scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
             ${changePrecision("tmpDecimal", target, evPrim, evNull)}
           """
       case DecimalType() =>
         (c, evPrim, evNull) =>
           s"""
-            org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone();
+            Decimal tmpDecimal = $c.clone();
             ${changePrecision("tmpDecimal", target, evPrim, evNull)}
           """
-      case LongType =>
+      case x: IntegralType =>
         (c, evPrim, evNull) =>
           s"""
-            org.apache.spark.sql.types.Decimal tmpDecimal =
-              new org.apache.spark.sql.types.Decimal().set($c);
+            Decimal tmpDecimal = Decimal.apply((long) $c);
             ${changePrecision("tmpDecimal", target, evPrim, evNull)}
           """
-      case x: NumericType =>
+      case x: FractionalType =>
         // All other numeric types can be represented precisely as Doubles
         (c, evPrim, evNull) =>
           s"""
             try {
-              org.apache.spark.sql.types.Decimal tmpDecimal =
-                new org.apache.spark.sql.types.Decimal().set(
-                  scala.math.BigDecimal.valueOf((double) $c));
+              Decimal tmpDecimal = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c));
               ${changePrecision("tmpDecimal", target, evPrim, evNull)}
             } catch (java.lang.NumberFormatException e) {
               $evNull = true;
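
The Cast.scala change routes every integral type through toLong: int/short/byte 
previously detoured through Double (exact for those widths, just slower), while 
the unified IntegralType branch must avoid Double to keep Long values above 2^53 
exact. A plain-Scala sketch of that precision hazard (BigDecimal stands in for 
Spark's Decimal here):

    object IntegralToDecimalSketch {
      def viaDouble(v: Long): BigDecimal = BigDecimal(v.toDouble) // detour through Double
      def direct(v: Long): BigDecimal = BigDecimal(v)             // direct conversion

      def main(args: Array[String]): Unit = {
        val v = (1L << 53) + 1 // smallest positive Long a Double cannot represent exactly
        println(direct(v) - viaDouble(v)) // 1 -- the Double detour drops the low bit
      }
    }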

http://git-wip-us.apache.org/repos/asf/spark/blob/c4fd2a24/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 624c3f3..d95805c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -139,9 +139,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
 
   def toBigDecimal: BigDecimal = {
     if (decimalVal.ne(null)) {
-      decimalVal(MATH_CONTEXT)
+      decimalVal
     } else {
-      BigDecimal(longVal, _scale)(MATH_CONTEXT)
+      BigDecimal(longVal, _scale)
     }
   }
 
@@ -280,13 +280,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
   }
 
   // HiveTypeCoercion will take care of the precision, scale of result
-  def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)
+  def * (that: Decimal): Decimal =
+    Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT))
 
   def / (that: Decimal): Decimal =
-    if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal)
+    if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, MATH_CONTEXT))
 
   def % (that: Decimal): Decimal =
-    if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal)
+    if (that.isZero) null
+    else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT))
 
   def remainder(that: Decimal): Decimal = this % that
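
With the operands no longer carrying a MathContext, divide must receive one 
explicitly: java.math.BigDecimal.divide without a MathContext throws 
ArithmeticException when the quotient has a non-terminating decimal expansion, 
while the rounded form returns a result at the context's precision. A minimal 
sketch (again assuming a 38-digit HALF_UP context in place of Spark's 
MATH_CONTEXT):

    import java.math.{BigDecimal => JBigDecimal, MathContext, RoundingMode}

    object DivideWithContextSketch {
      val mc = new MathContext(38, RoundingMode.HALF_UP) // assumed stand-in for MATH_CONTEXT

      def main(args: Array[String]): Unit = {
        val one = JBigDecimal.ONE
        val three = new JBigDecimal(3)

        // No MathContext: 1/3 has no exact decimal representation.
        try { one.divide(three); () }
        catch { case e: ArithmeticException => println(s"no context: ${e.getMessage}") }

        // With the context, the quotient is rounded to 38 significant digits.
        println(one.divide(three, mc))
      }
    }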
 

