Repository: spark
Updated Branches:
  refs/heads/master ff876137f -> d03e0af80


[SPARK-25522][SQL] Improve type promotion for input arguments of elementAt function

## What changes were proposed in this pull request?
In ElementAt, when the first argument is a MapType, the map's key type and the 
second argument should be coerced to a common type based on 
findTightestCommonType. This does not happen currently, so we may produce wrong 
output: the right-hand-side double expression is incorrectly downcast to int.

```SQL
spark-sql> select element_at(map(1,"one", 2, "two"), 2.2);

two
```
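
After this change, the map key type and the lookup value are promoted to a 
common type via findTightestCommonType, so a double lookup that matches no 
integer key correctly returns NULL, as exercised by the new 
DataFrameFunctionsSuite test:

```SQL
spark-sql> select element_at(map(1, 'a', 2, 'b'), 1.23D);

NULL
```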

Also, when the first argument is an ArrayType, the second argument should be an 
integer type, or a smaller integral type that can be safely cast to an integer 
type. Currently we may perform an unsafe cast. In the following case we should 
fail with an error, since 2.2 is not an integer index, but instead we currently 
downcast it to int and return a result.

```SQL
spark-sql> select element_at(array(1,2), 1.24D);

1
```
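
With this change, a non-integer index is rejected at analysis time instead; 
the expected message below is the new checkInputDataTypes error text asserted 
in DataFrameFunctionsSuite:

```SQL
spark-sql> select element_at(array('a', 'b'), 1L);

-- AnalysisException: Input to function element_at should have been array
-- followed by a int, but it's [array<string>, bigint].
```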
This PR also supports implicit casts between two MapTypes, following the same 
logic that exists today for implicit casts between two array types.
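
For example, a bigint lookup key now promotes the map side as well, implicitly 
casting map<int,string> to map<bigint,string> before the lookup (covered by the 
new DataFrameFunctionsSuite tests):

```SQL
spark-sql> select element_at(map(1, 'a', 2, 'b'), 2L);

b
```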
## How was this patch tested?
Added new tests in DataFrameFunctionsSuite and TypeCoercionSuite.

Closes #22544 from dilipbiswal/SPARK-25522.

Authored-by: Dilip Biswal <dbis...@us.ibm.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d03e0af8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d03e0af8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d03e0af8

Branch: refs/heads/master
Commit: d03e0af80d7659f12821cc2442efaeaee94d3985
Parents: ff87613
Author: Dilip Biswal <dbis...@us.ibm.com>
Authored: Thu Sep 27 15:04:59 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Thu Sep 27 15:04:59 2018 +0800

----------------------------------------------------------------------
 .../sql/catalyst/analysis/TypeCoercion.scala    | 19 +++++
 .../spark/sql/catalyst/expressions/Cast.scala   |  2 +-
 .../expressions/collectionOperations.scala      | 37 ++++++----
 .../catalyst/analysis/TypeCoercionSuite.scala   | 43 +++++++++--
 .../spark/sql/DataFrameFunctionsSuite.scala     | 75 +++++++++++++++++++-
 5 files changed, 154 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d03e0af8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 49d286f..72ac80e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -950,6 +950,25 @@ object TypeCoercion {
             if !Cast.forceNullable(fromType, toType) =>
           implicitCast(fromType, toType).map(ArrayType(_, false)).orNull
 
+        // Implicit cast between Map types.
+        // Follows the same semantics of implicit casting between two array types.
+        // Refer to documentation above. Make sure that both keys and values
+        // cannot be null after the implicit cast operation by calling the
+        // forceNullable method.
+        case (MapType(fromKeyType, fromValueType, fn), MapType(toKeyType, toValueType, tn))
+            if !Cast.forceNullable(fromKeyType, toKeyType) && Cast.resolvableNullability(fn, tn) =>
+          if (Cast.forceNullable(fromValueType, toValueType) && !tn) {
+            null
+          } else {
+            val newKeyType = implicitCast(fromKeyType, toKeyType).orNull
+            val newValueType = implicitCast(fromValueType, toValueType).orNull
+            if (newKeyType != null && newValueType != null) {
+              MapType(newKeyType, newValueType, tn)
+            } else {
+              null
+            }
+          }
+
         case _ => null
       }
       Option(ret)

http://git-wip-us.apache.org/repos/asf/spark/blob/d03e0af8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8f77799..ee463bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -183,7 +183,7 @@ object Cast {
     case _ => false
   }
 
-  private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
+  def resolvableNullability(from: Boolean, to: Boolean): Boolean = !from || to
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/d03e0af8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 9cc7dba..b24d748 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2154,21 +2154,34 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
   }
 
   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(TypeCollection(ArrayType, MapType),
-      left.dataType match {
-        case _: ArrayType => IntegerType
-        case _: MapType => mapKeyType
-        case _ => AnyDataType // no match for a wrong 'left' expression type
-      }
-    )
+    (left.dataType, right.dataType) match {
+      case (arr: ArrayType, e2: IntegralType) if (e2 != LongType) =>
+        Seq(arr, IntegerType)
+      case (MapType(keyType, valueType, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(keyType, e2) match {
+          case Some(dt) => Seq(MapType(dt, valueType, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case (l, r) => Seq.empty
+
+    }
   }
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    super.checkInputDataTypes() match {
-      case f: TypeCheckResult.TypeCheckFailure => f
-      case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] =>
-        TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName")
-      case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess
+    (left.dataType, right.dataType) match {
+      case (_: ArrayType, e2) if e2 != IntegerType =>
+        TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+          s"been ${ArrayType.simpleString} followed by a ${IntegerType.simpleString}, but it's " +
+          s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
+      case (MapType(e1, _, _), e2) if (!e2.sameType(e1)) =>
+        TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+          s"been ${MapType.simpleString} followed by a value of same key type, but it's " +
+          s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
+      case (e1, _) if (!e1.isInstanceOf[MapType] && !e1.isInstanceOf[ArrayType]) =>
+        TypeCheckResult.TypeCheckFailure(s"The first argument to function $prettyName should " +
+          s"have been ${ArrayType.simpleString} or ${MapType.simpleString} type, but its " +
+          s"${left.dataType.catalogString} type.")
+      case _ => TypeCheckResult.TypeCheckSuccess
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d03e0af8/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 0594673..0eba1c5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -257,12 +257,43 @@ class TypeCoercionSuite extends AnalysisTest {
     shouldNotCast(checkedType, IntegralType)
   }
 
-  test("implicit type cast - MapType(StringType, StringType)") {
-    val checkedType = MapType(StringType, StringType)
-    checkTypeCasting(checkedType, castableTypes = Seq(checkedType))
-    shouldNotCast(checkedType, DecimalType)
-    shouldNotCast(checkedType, NumericType)
-    shouldNotCast(checkedType, IntegralType)
+  test("implicit type cast between two Map types") {
+    val sourceType = MapType(IntegerType, IntegerType, true)
+    val castableTypes = numericTypes ++ Seq(StringType).filter(!Cast.forceNullable(IntegerType, _))
+    val targetTypes = numericTypes.filter(!Cast.forceNullable(IntegerType, _)).map { t =>
+      MapType(t, sourceType.valueType, valueContainsNull = true)
+    }
+    val nonCastableTargetTypes = allTypes.filterNot(castableTypes.contains(_)).map { t =>
+      MapType(t, sourceType.valueType, valueContainsNull = true)
+    }
+
+    // Tests that it's possible to set up implicit casts between two map types when
+    // the source map's key type is integer and the target map's key type is either Byte,
+    // Short, Long, Double, Float, Decimal(38, 18) or String.
+    targetTypes.foreach { targetType =>
+      shouldCast(sourceType, targetType, targetType)
+    }
+
+    // Tests that it's not possible to set up implicit casts between two map types when
+    // the source map's key type is integer and the target map's key type is either Binary,
+    // Boolean, Date, Timestamp, Array, Struct, CalendarIntervalType or NullType.
+    nonCastableTargetTypes.foreach { targetType =>
+      shouldNotCast(sourceType, targetType)
+    }
+
+    // Tests that it's not possible to cast from a nullable map type to a non-nullable
+    // map type.
+    val targetNotNullableTypes = allTypes.filterNot(_ == IntegerType).map { t =>
+      MapType(t, sourceType.valueType, valueContainsNull = false)
+    }
+    val sourceMapExprWithValueNull =
+      CreateMap(Seq(Literal.default(sourceType.keyType),
+        Literal.create(null, sourceType.valueType)))
+    targetNotNullableTypes.foreach { targetType =>
+      val castDefault =
+        TypeCoercion.ImplicitTypeCasts.implicitCast(sourceMapExprWithValueNull, targetType)
+      assert(castDefault.isEmpty,
+        s"Should not be able to cast $sourceType to $targetType, but got $castDefault")
+    }
   }
 
   test("implicit type cast - StructType().add(\"a1\", StringType)") {

http://git-wip-us.apache.org/repos/asf/spark/blob/d03e0af8/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 88dbae8..60ebc5e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1211,11 +1211,80 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       Seq(Row("3"), Row(""), Row(null))
     )
 
-    val e = intercept[AnalysisException] {
+    val e1 = intercept[AnalysisException] {
       Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)")
     }
-    assert(e.message.contains(
-      "argument 1 requires (array or map) type, however, '`_1`' is of string 
type"))
+    val errorMsg1 =
+      s"""
+         |The first argument to function element_at should have been array or 
map type, but
+         |its string type.
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e1.message.contains(errorMsg1))
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(array(2, 1), 2S)"),
+      Seq(Row(1))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(array('a', 'b'), 1Y)"),
+      Seq(Row("a"))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(array(1, 2, 3), 3)"),
+      Seq(Row(3))
+    )
+
+    val e2 = intercept[AnalysisException] {
+      OneRowRelation().selectExpr("element_at(array('a', 'b'), 1L)")
+    }
+    val errorMsg2 =
+      s"""
+         |Input to function element_at should have been array followed by a int, but it's
+         |[array<string>, bigint].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e2.message.contains(errorMsg2))
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2Y)"),
+      Seq(Row("b"))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1S)"),
+      Seq(Row("a"))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2)"),
+      Seq(Row("b"))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2L)"),
+      Seq(Row("b"))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.0D)"),
+      Seq(Row("a"))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.23D)"),
+      Seq(Row(null))
+    )
+
+    val e3 = intercept[AnalysisException] {
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), '1')")
+    }
+    val errorMsg3 =
+      s"""
+         |Input to function element_at should have been map followed by a value of same
+         |key type, but it's [map<int,string>, string].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e3.message.contains(errorMsg3))
   }
 
   test("array_union functions") {

