Repository: spark
Updated Branches:
  refs/heads/branch-2.0 88603bd4f -> 7ef1d1c61


[SPARK-16278][SPARK-16279][SQL] Implement map_keys/map_values SQL functions

This PR adds native `map_keys` and `map_values` SQL functions in order to remove the Hive fallback.

Passes the Jenkins tests, including the new test cases.

Author: Dongjoon Hyun <dongj...@apache.org>

Closes #13967 from dongjoon-hyun/SPARK-16278.

(cherry picked from commit 54b27c1797fcd32b3f3e9d44e1a149ae396a61e6)
Signed-off-by: Reynold Xin <r...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7ef1d1c6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7ef1d1c6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7ef1d1c6

Branch: refs/heads/branch-2.0
Commit: 7ef1d1c618100313dbbdb6f615d9f87ff67e895d
Parents: 88603bd
Author: Dongjoon Hyun <dongj...@apache.org>
Authored: Sun Jul 3 16:59:40 2016 +0800
Committer: Reynold Xin <r...@databricks.com>
Committed: Thu Jul 7 21:02:50 2016 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |  2 +
 .../expressions/collectionOperations.scala      | 48 ++++++++++++++++++++
 .../expressions/CollectionFunctionsSuite.scala  | 13 ++++++
 .../spark/sql/DataFrameFunctionsSuite.scala     | 16 +++++++
 .../spark/sql/hive/HiveSessionCatalog.scala     |  1 -
 5 files changed, 79 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7ef1d1c6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 95be0d6..27c3a09 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -170,6 +170,8 @@ object FunctionRegistry {
     expression[IsNotNull]("isnotnull"),
     expression[Least]("least"),
     expression[CreateMap]("map"),
+    expression[MapKeys]("map_keys"),
+    expression[MapValues]("map_values"),
     expression[CreateNamedStruct]("named_struct"),
     expression[NaNvl]("nanvl"),
     expression[NullIf]("nullif"),

http://git-wip-us.apache.org/repos/asf/spark/blob/7ef1d1c6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c71cb73..2e8ea11 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -44,6 +44,54 @@ case class Size(child: Expression) extends UnaryExpression 
with ExpectsInputType
 }
 
 /**
+ * Returns an unordered array containing the keys of the map.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(map) - Returns an unordered array containing the keys of the 
map.",
+  extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [1,2]")
+case class MapKeys(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
+
+  override def dataType: DataType = 
ArrayType(child.dataType.asInstanceOf[MapType].keyType)
+
+  override def nullSafeEval(map: Any): Any = {
+    map.asInstanceOf[MapData].keyArray()
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).keyArray();")
+  }
+
+  override def prettyName: String = "map_keys"
+}
+
+/**
+ * Returns an unordered array containing the values of the map.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(map) - Returns an unordered array containing the values of 
the map.",
+  extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [\"a\",\"b\"]")
+case class MapValues(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
+
+  override def dataType: DataType = 
ArrayType(child.dataType.asInstanceOf[MapType].valueType)
+
+  override def nullSafeEval(map: Any): Any = {
+    map.asInstanceOf[MapData].valueArray()
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).valueArray();")
+  }
+
+  override def prettyName: String = "map_values"
+}
+
+/**
  * Sorts the input array in ascending / descending order according to the 
natural ordering of
  * the array elements and returns it.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/7ef1d1c6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
index 1aae467..a5f784f 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -44,6 +44,19 @@ class CollectionFunctionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
   }
 
+  test("MapKeys/MapValues") {
+    val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, 
StringType))
+    val m1 = Literal.create(Map[String, String](), MapType(StringType, 
StringType))
+    val m2 = Literal.create(null, MapType(StringType, StringType))
+
+    checkEvaluation(MapKeys(m0), Seq("a", "b"))
+    checkEvaluation(MapValues(m0), Seq("1", "2"))
+    checkEvaluation(MapKeys(m1), Seq())
+    checkEvaluation(MapValues(m1), Seq())
+    checkEvaluation(MapKeys(m2), null)
+    checkEvaluation(MapValues(m2), null)
+  }
+
   test("Sort Array") {
     val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
     val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))

http://git-wip-us.apache.org/repos/asf/spark/blob/7ef1d1c6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 73d7765..0f6c49e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -352,6 +352,22 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSQLContext {
     )
   }
 
+  test("map_keys/map_values function") {
+    val df = Seq(
+      (Map[Int, Int](1 -> 100, 2 -> 200), "x"),
+      (Map[Int, Int](), "y"),
+      (Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), "z")
+    ).toDF("a", "b")
+    checkAnswer(
+      df.selectExpr("map_keys(a)"),
+      Seq(Row(Seq(1, 2)), Row(Seq.empty), Row(Seq(1, 2, 3)))
+    )
+    checkAnswer(
+      df.selectExpr("map_values(a)"),
+      Seq(Row(Seq(100, 200)), Row(Seq.empty), Row(Seq(100, 200, 300)))
+    )
+  }
+
   test("array contains function") {
     val df = Seq(
       (Seq[Int](1, 2), "x"),

http://git-wip-us.apache.org/repos/asf/spark/blob/7ef1d1c6/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 1479554..4c986b0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -236,7 +236,6 @@ private[sql] class HiveSessionCatalog(
   // str_to_map, windowingtablefunction.
   private val hiveFunctions = Seq(
     "hash", "java_method", "histogram_numeric",
-    "map_keys", "map_values",
     "parse_url", "percentile", "percentile_approx", "reflect", "sentences", 
"stack", "str_to_map",
     "xpath", "xpath_boolean", "xpath_double", "xpath_float", "xpath_int", 
"xpath_long",
     "xpath_number", "xpath_short", "xpath_string",


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to