Repository: spark
Updated Branches:
  refs/heads/master c7967c604 -> e07aee216


[SPARK-24636][SQL] Type coercion of arrays for array_join function

## What changes were proposed in this pull request?
Presto's implementation of `array_join` accepts arrays of arbitrary primitive types as input:

```
presto> SELECT array_join(ARRAY [1, 2, 3], ', ');
_col0
---------
1, 2, 3
(1 row)
```

This PR proposes to implement a type coercion rule for the ```array_join```
function that converts arrays of primitive as well as non-primitive types to
arrays of strings.

## How was this patch tested?

New test cases were added to:
- sql-tests/inputs/typeCoercion/native/arrayJoin.sql
- DataFrameFunctionsSuite.scala

Author: Marek Novotny <mn.mi...@gmail.com>

Closes #21620 from mn-mikke/SPARK-24636.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e07aee21
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e07aee21
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e07aee21

Branch: refs/heads/master
Commit: e07aee2165af4d301ae12005a6d9ffb030bc2650
Parents: c7967c6
Author: Marek Novotny <mn.mi...@gmail.com>
Authored: Tue Jun 26 09:51:55 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Tue Jun 26 09:51:55 2018 +0800

----------------------------------------------------------------------
 .../sql/catalyst/analysis/TypeCoercion.scala    |  8 ++
 .../expressions/collectionOperations.scala      |  1 +
 .../inputs/typeCoercion/native/arrayJoin.sql    | 11 +++
 .../typeCoercion/native/arrayJoin.sql.out       | 90 ++++++++++++++++++++
 .../spark/sql/DataFrameFunctionsSuite.scala     | 17 ++++
 5 files changed, 127 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e07aee21/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index b2817b0..6379239 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -536,6 +536,14 @@ object TypeCoercion {
           case None => c
         }
 
+      case aj @ ArrayJoin(arr, d, nr) if 
!ArrayType(StringType).acceptsType(arr.dataType) &&
+        ArrayType.acceptsType(arr.dataType) =>
+        val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull
+        ImplicitTypeCasts.implicitCast(arr, ArrayType(StringType, 
containsNull)) match {
+          case Some(castedArr) => ArrayJoin(castedArr, d, nr)
+          case None => aj
+        }
+
       case m @ CreateMap(children) if m.keys.length == m.values.length &&
         (!haveSameType(m.keys) || !haveSameType(m.values)) =>
         val newKeys = if (haveSameType(m.keys)) {

http://git-wip-us.apache.org/repos/asf/spark/blob/e07aee21/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index b6137b0..58612f6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1621,6 +1621,7 @@ case class ArrayJoin(
 
   override def dataType: DataType = StringType
 
+  override def prettyName: String = "array_join"
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/e07aee21/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql
 
b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql
new file mode 100644
index 0000000..99729c0
--- /dev/null
+++ 
b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql
@@ -0,0 +1,11 @@
+SELECT array_join(array(true, false), ', ');
+SELECT array_join(array(2Y, 1Y), ', ');
+SELECT array_join(array(2S, 1S), ', ');
+SELECT array_join(array(2, 1), ', ');
+SELECT array_join(array(2L, 1L), ', ');
+SELECT array_join(array(9223372036854775809, 9223372036854775808), ', ');
+SELECT array_join(array(2.0D, 1.0D), ', ');
+SELECT array_join(array(float(2.0), float(1.0)), ', ');
+SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', ');
+SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp 
'2016-11-12 20:54:00.000'), ', ');
+SELECT array_join(array('a', 'b'), ', ');

http://git-wip-us.apache.org/repos/asf/spark/blob/e07aee21/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out
 
b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out
new file mode 100644
index 0000000..b23a62d
--- /dev/null
+++ 
b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out
@@ -0,0 +1,90 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 11
+
+
+-- !query 0
+SELECT array_join(array(true, false), ', ')
+-- !query 0 schema
+struct<array_join(array(true, false), , ):string>
+-- !query 0 output
+true, false
+
+
+-- !query 1
+SELECT array_join(array(2Y, 1Y), ', ')
+-- !query 1 schema
+struct<array_join(array(2, 1), , ):string>
+-- !query 1 output
+2, 1
+
+
+-- !query 2
+SELECT array_join(array(2S, 1S), ', ')
+-- !query 2 schema
+struct<array_join(array(2, 1), , ):string>
+-- !query 2 output
+2, 1
+
+
+-- !query 3
+SELECT array_join(array(2, 1), ', ')
+-- !query 3 schema
+struct<array_join(array(2, 1), , ):string>
+-- !query 3 output
+2, 1
+
+
+-- !query 4
+SELECT array_join(array(2L, 1L), ', ')
+-- !query 4 schema
+struct<array_join(array(2, 1), , ):string>
+-- !query 4 output
+2, 1
+
+
+-- !query 5
+SELECT array_join(array(9223372036854775809, 9223372036854775808), ', ')
+-- !query 5 schema
+struct<array_join(array(9223372036854775809, 9223372036854775808), , ):string>
+-- !query 5 output
+9223372036854775809, 9223372036854775808
+
+
+-- !query 6
+SELECT array_join(array(2.0D, 1.0D), ', ')
+-- !query 6 schema
+struct<array_join(array(2.0, 1.0), , ):string>
+-- !query 6 output
+2.0, 1.0
+
+
+-- !query 7
+SELECT array_join(array(float(2.0), float(1.0)), ', ')
+-- !query 7 schema
+struct<array_join(array(CAST(2.0 AS FLOAT), CAST(1.0 AS FLOAT)), , ):string>
+-- !query 7 output
+2.0, 1.0
+
+
+-- !query 8
+SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', ')
+-- !query 8 schema
+struct<array_join(array(DATE '2016-03-14', DATE '2016-03-13'), , ):string>
+-- !query 8 output
+2016-03-14, 2016-03-13
+
+
+-- !query 9
+SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp 
'2016-11-12 20:54:00.000'), ', ')
+-- !query 9 schema
+struct<array_join(array(TIMESTAMP('2016-11-15 20:54:00.0'), 
TIMESTAMP('2016-11-12 20:54:00.0')), , ):string>
+-- !query 9 output
+2016-11-15 20:54:00, 2016-11-12 20:54:00
+
+
+-- !query 10
+SELECT array_join(array('a', 'b'), ', ')
+-- !query 10 schema
+struct<array_join(array(a, b), , ):string>
+-- !query 10 output
+a, b

http://git-wip-us.apache.org/repos/asf/spark/blob/e07aee21/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 47fe67d..5d6a6c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -805,6 +805,23 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSQLContext {
     checkAnswer(
       df.selectExpr("array_join(x, delimiter, 'NULL')"),
       Seq(Row("a,b"), Row("a,NULL,b"), Row("")))
+
+    val idf = Seq(Seq(1, 2, 3)).toDF("x")
+
+    checkAnswer(
+      idf.select(array_join(idf("x"), ", ")),
+      Seq(Row("1, 2, 3"))
+    )
+    checkAnswer(
+      idf.selectExpr("array_join(x, ', ')"),
+      Seq(Row("1, 2, 3"))
+    )
+    intercept[AnalysisException] {
+      idf.selectExpr("array_join(x, 1)")
+    }
+    intercept[AnalysisException] {
+      idf.selectExpr("array_join(x, ', ', 1)")
+    }
   }
 
   test("array_min function") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to