Repository: spark
Updated Branches:
  refs/heads/branch-2.0 969313bb2 -> 2daab33c4


[SPARK-16714][SPARK-16735][SPARK-16646] array, map, greatest, least's type 
coercion should handle decimal type

## What changes were proposed in this pull request?

Here is a table about the behaviours of `array`/`map` and `greatest`/`least` in 
Hive, MySQL and Postgres:

|    |Hive|MySQL|Postgres|
|---|---|---|---|
|`array`/`map`|can find a wider type with decimal type arguments, and will 
truncate the wider decimal type if necessary|can find a wider type with decimal 
type arguments, no truncation problem|can find a wider type with decimal type 
arguments, no truncation problem|
|`greatest`/`least`|can find a wider type with decimal type arguments, and 
truncate if necessary, but can't do string promotion|can find a wider type with 
decimal type arguments, no truncation problem, but can't do string 
promotion|can find a wider type with decimal type arguments, no truncation 
problem, but can't do string promotion|

I think these behaviours make sense and Spark SQL should follow them.

This PR fixes `array` and `map` by using `findWiderCommonType` to get the wider 
type.
This PR fixes `greatest` and `least` by adding 
`findWiderTypeWithoutStringPromotion`, which provides semantics similar to 
`findWiderCommonType`, but without string promotion.

## How was this patch tested?

new tests in `TypeCoercionSuite`

Author: Wenchen Fan <wenc...@databricks.com>
Author: Yin Huai <yh...@databricks.com>

Closes #14439 from cloud-fan/bug.

(cherry picked from commit b55f34370f695de355b72c1518b5f2a45c324af0)
Signed-off-by: Yin Huai <yh...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2daab33c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2daab33c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2daab33c

Branch: refs/heads/branch-2.0
Commit: 2daab33c4ed2cdbe1025252ae10bf19320df9d25
Parents: 969313b
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Wed Aug 3 11:15:09 2016 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Wed Aug 3 11:15:32 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/analysis/TypeCoercion.scala    | 47 +++++++++-----
 .../analysis/ExpressionTypeCheckingSuite.scala  |  1 -
 .../catalyst/analysis/TypeCoercionSuite.scala   | 67 ++++++++++++++++++++
 3 files changed, 97 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2daab33c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 8503b8d..021952e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -109,18 +109,6 @@ object TypeCoercion {
   }
 
   /**
-   * Similar to [[findTightestCommonType]], if can not find the 
TightestCommonType, try to use
-   * [[findTightestCommonTypeToString]] to find the TightestCommonType.
-   */
-  private def findTightestCommonTypeAndPromoteToString(types: Seq[DataType]): 
Option[DataType] = {
-    types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
-      case None => None
-      case Some(d) =>
-        findTightestCommonTypeToString(d, c)
-    })
-  }
-
-  /**
    * Find the tightest common type of a set of types by continuously applying
    * `findTightestCommonTypeOfTwo` on these types.
    */
@@ -157,6 +145,28 @@ object TypeCoercion {
     })
   }
 
+  /**
+   * Similar to [[findWiderCommonType]], but can't promote to string. This is 
also similar to
+   * [[findTightestCommonType]], but can handle decimal types. If the wider 
decimal type exceeds
+   * system limitation, this rule will truncate the decimal type before returning 
it.
+   */
+  private def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): 
Option[DataType] = {
+    types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
+      case Some(d) => findTightestCommonTypeOfTwo(d, c).orElse((d, c) match {
+        case (t1: DecimalType, t2: DecimalType) =>
+          Some(DecimalPrecision.widerDecimalType(t1, t2))
+        case (t: IntegralType, d: DecimalType) =>
+          Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
+        case (d: DecimalType, t: IntegralType) =>
+          Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
+        case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: 
FractionalType) =>
+          Some(DoubleType)
+        case _ => None
+      })
+      case None => None
+    })
+  }
+
   private def haveSameType(exprs: Seq[Expression]): Boolean =
     exprs.map(_.dataType).distinct.length == 1
 
@@ -440,7 +450,7 @@ object TypeCoercion {
 
       case a @ CreateArray(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
-        findTightestCommonTypeAndPromoteToString(types) match {
+        findWiderCommonType(types) match {
           case Some(finalDataType) => CreateArray(children.map(Cast(_, 
finalDataType)))
           case None => a
         }
@@ -451,7 +461,7 @@ object TypeCoercion {
           m.keys
         } else {
           val types = m.keys.map(_.dataType)
-          findTightestCommonTypeAndPromoteToString(types) match {
+          findWiderCommonType(types) match {
             case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
             case None => m.keys
           }
@@ -461,7 +471,7 @@ object TypeCoercion {
           m.values
         } else {
           val types = m.values.map(_.dataType)
-          findTightestCommonTypeAndPromoteToString(types) match {
+          findWiderCommonType(types) match {
             case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
             case None => m.values
           }
@@ -494,16 +504,19 @@ object TypeCoercion {
           case None => c
         }
 
+      // When finding wider type for `Greatest` and `Least`, we should handle 
decimal types even if
+      // we need to truncate, but we should not promote one side to string if 
the other side is
+      // string.
       case g @ Greatest(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
-        findTightestCommonType(types) match {
+        findWiderTypeWithoutStringPromotion(types) match {
           case Some(finalDataType) => Greatest(children.map(Cast(_, 
finalDataType)))
           case None => g
         }
 
       case l @ Least(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
-        findTightestCommonType(types) match {
+        findWiderTypeWithoutStringPromotion(types) match {
           case Some(finalDataType) => Least(children.map(Cast(_, 
finalDataType)))
           case None => l
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/2daab33c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 76e42d9..3aefb3c 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -216,7 +216,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
       assertError(operator(Seq('booleanField)), "requires at least 2 
arguments")
       assertError(operator(Seq('intField, 'stringField)), "should all have the 
same type")
-      assertError(operator(Seq('intField, 'decimalField)), "should all have 
the same type")
       assertError(operator(Seq('mapField, 'mapField)), "does not support 
ordering")
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/2daab33c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 971c99b..a13c45f 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -283,6 +283,24 @@ class TypeCoercionSuite extends PlanTest {
         :: Cast(Literal(1), StringType)
         :: Cast(Literal("a"), StringType)
         :: Nil))
+
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      CreateArray(Literal.create(null, DecimalType(5, 3))
+        :: Literal(1)
+        :: Nil),
+      CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(13, 
3))
+        :: Literal(1).cast(DecimalType(13, 3))
+        :: Nil))
+
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      CreateArray(Literal.create(null, DecimalType(5, 3))
+        :: Literal.create(null, DecimalType(22, 10))
+        :: Literal.create(null, DecimalType(38, 38))
+        :: Nil),
+      CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 
38))
+        :: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38))
+        :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
+        :: Nil))
   }
 
   test("CreateMap casts") {
@@ -298,6 +316,17 @@ class TypeCoercionSuite extends PlanTest {
         :: Cast(Literal.create(2.0, FloatType), FloatType)
         :: Literal("b")
         :: Nil))
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      CreateMap(Literal.create(null, DecimalType(5, 3))
+        :: Literal("a")
+        :: Literal.create(2.0, FloatType)
+        :: Literal("b")
+        :: Nil),
+      CreateMap(Literal.create(null, DecimalType(5, 3)).cast(DoubleType)
+        :: Literal("a")
+        :: Literal.create(2.0, FloatType).cast(DoubleType)
+        :: Literal("b")
+        :: Nil))
     // type coercion for map values
     ruleTest(TypeCoercion.FunctionArgumentConversion,
       CreateMap(Literal(1)
@@ -310,6 +339,17 @@ class TypeCoercionSuite extends PlanTest {
         :: Literal(2)
         :: Cast(Literal(3.0), StringType)
         :: Nil))
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      CreateMap(Literal(1)
+        :: Literal.create(null, DecimalType(38, 0))
+        :: Literal(2)
+        :: Literal.create(null, DecimalType(38, 38))
+        :: Nil),
+      CreateMap(Literal(1)
+        :: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38))
+        :: Literal(2)
+        :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
+        :: Nil))
     // type coercion for both map keys and values
     ruleTest(TypeCoercion.FunctionArgumentConversion,
       CreateMap(Literal(1)
@@ -344,6 +384,33 @@ class TypeCoercionSuite extends PlanTest {
           :: Cast(Literal(1), DecimalType(22, 0))
           :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), 
DecimalType(22, 0))
           :: Nil))
+      ruleTest(TypeCoercion.FunctionArgumentConversion,
+        operator(Literal(1.0)
+          :: Literal.create(null, DecimalType(10, 5))
+          :: Literal(1)
+          :: Nil),
+        operator(Literal(1.0).cast(DoubleType)
+          :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType)
+          :: Literal(1).cast(DoubleType)
+          :: Nil))
+      ruleTest(TypeCoercion.FunctionArgumentConversion,
+        operator(Literal.create(null, DecimalType(15, 0))
+          :: Literal.create(null, DecimalType(10, 5))
+          :: Literal(1)
+          :: Nil),
+        operator(Literal.create(null, DecimalType(15, 0)).cast(DecimalType(20, 
5))
+          :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5))
+          :: Literal(1).cast(DecimalType(20, 5))
+          :: Nil))
+      ruleTest(TypeCoercion.FunctionArgumentConversion,
+        operator(Literal.create(2L, LongType)
+          :: Literal(1)
+          :: Literal.create(null, DecimalType(10, 5))
+          :: Nil),
+        operator(Literal.create(2L, LongType).cast(DecimalType(25, 5))
+          :: Literal(1).cast(DecimalType(25, 5))
+          :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(25, 5))
+          :: Nil))
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to