Repository: spark
Updated Branches:
  refs/heads/master ea990f969 -> 54b27c179


[SPARK-16278][SPARK-16279][SQL] Implement map_keys/map_values SQL functions

## What changes were proposed in this pull request?

This PR adds native `map_keys` and `map_values` SQL functions so that these functions no longer fall back to the Hive implementations.

## How was this patch tested?

Passes the Jenkins tests, including the new test cases.

Author: Dongjoon Hyun <dongj...@apache.org>

Closes #13967 from dongjoon-hyun/SPARK-16278.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/54b27c17
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/54b27c17
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/54b27c17

Branch: refs/heads/master
Commit: 54b27c1797fcd32b3f3e9d44e1a149ae396a61e6
Parents: ea990f9
Author: Dongjoon Hyun <dongj...@apache.org>
Authored: Sun Jul 3 16:59:40 2016 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Sun Jul 3 16:59:40 2016 +0800

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |  2 +
 .../expressions/collectionOperations.scala      | 48 ++++++++++++++++++++
 .../expressions/CollectionFunctionsSuite.scala  | 13 ++++++
 .../spark/sql/DataFrameFunctionsSuite.scala     | 16 +++++++
 .../spark/sql/hive/HiveSessionCatalog.scala     |  1 -
 5 files changed, 79 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/54b27c17/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 26b0c30..e7f335f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -171,6 +171,8 @@ object FunctionRegistry {
     expression[IsNotNull]("isnotnull"),
     expression[Least]("least"),
     expression[CreateMap]("map"),
+    expression[MapKeys]("map_keys"),
+    expression[MapValues]("map_values"),
     expression[CreateNamedStruct]("named_struct"),
     expression[NaNvl]("nanvl"),
     expression[NullIf]("nullif"),

http://git-wip-us.apache.org/repos/asf/spark/blob/54b27c17/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c71cb73..2e8ea11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -44,6 +44,54 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
 }
 
 /**
+ * Returns an unordered array containing the keys of the map.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(map) - Returns an unordered array containing the keys of the map.",
+  extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [1,2]")
+case class MapKeys(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
+
+  override def dataType: DataType = ArrayType(child.dataType.asInstanceOf[MapType].keyType)
+
+  override def nullSafeEval(map: Any): Any = {
+    map.asInstanceOf[MapData].keyArray()
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).keyArray();")
+  }
+
+  override def prettyName: String = "map_keys"
+}
+
+/**
+ * Returns an unordered array containing the values of the map.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(map) - Returns an unordered array containing the values of the map.",
+  extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [\"a\",\"b\"]")
+case class MapValues(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
+
+  override def dataType: DataType = ArrayType(child.dataType.asInstanceOf[MapType].valueType)
+
+  override def nullSafeEval(map: Any): Any = {
+    map.asInstanceOf[MapData].valueArray()
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).valueArray();")
+  }
+
+  override def prettyName: String = "map_values"
+}
+
+/**
  * Sorts the input array in ascending / descending order according to the natural ordering of
  * the array elements and returns it.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/54b27c17/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
index 1aae467..a5f784f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -44,6 +44,19 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
   }
 
+  test("MapKeys/MapValues") {
+    val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType))
+    val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+    val m2 = Literal.create(null, MapType(StringType, StringType))
+
+    checkEvaluation(MapKeys(m0), Seq("a", "b"))
+    checkEvaluation(MapValues(m0), Seq("1", "2"))
+    checkEvaluation(MapKeys(m1), Seq())
+    checkEvaluation(MapValues(m1), Seq())
+    checkEvaluation(MapKeys(m2), null)
+    checkEvaluation(MapValues(m2), null)
+  }
+
   test("Sort Array") {
     val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
     val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))

http://git-wip-us.apache.org/repos/asf/spark/blob/54b27c17/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 73d7765..0f6c49e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -352,6 +352,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("map_keys/map_values function") {
+    val df = Seq(
+      (Map[Int, Int](1 -> 100, 2 -> 200), "x"),
+      (Map[Int, Int](), "y"),
+      (Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), "z")
+    ).toDF("a", "b")
+    checkAnswer(
+      df.selectExpr("map_keys(a)"),
+      Seq(Row(Seq(1, 2)), Row(Seq.empty), Row(Seq(1, 2, 3)))
+    )
+    checkAnswer(
+      df.selectExpr("map_values(a)"),
+      Seq(Row(Seq(100, 200)), Row(Seq.empty), Row(Seq(100, 200, 300)))
+    )
+  }
+
   test("array contains function") {
     val df = Seq(
       (Seq[Int](1, 2), "x"),

http://git-wip-us.apache.org/repos/asf/spark/blob/54b27c17/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 1fffadb..53990b8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -239,7 +239,6 @@ private[sql] class HiveSessionCatalog(
   // str_to_map, windowingtablefunction.
   private val hiveFunctions = Seq(
     "hash", "java_method", "histogram_numeric",
-    "map_keys", "map_values",
     "parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
     "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
     "xpath_number", "xpath_short", "xpath_string",


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to