Repository: spark Updated Branches: refs/heads/branch-2.0 88603bd4f -> 7ef1d1c61
[SPARK-16278][SPARK-16279][SQL] Implement map_keys/map_values SQL functions This PR adds `map_keys` and `map_values` SQL functions in order to remove Hive fallback. Pass the Jenkins tests including new testcases. Author: Dongjoon Hyun <dongj...@apache.org> Closes #13967 from dongjoon-hyun/SPARK-16278. (cherry picked from commit 54b27c1797fcd32b3f3e9d44e1a149ae396a61e6) Signed-off-by: Reynold Xin <r...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7ef1d1c6 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7ef1d1c6 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7ef1d1c6 Branch: refs/heads/branch-2.0 Commit: 7ef1d1c618100313dbbdb6f615d9f87ff67e895d Parents: 88603bd Author: Dongjoon Hyun <dongj...@apache.org> Authored: Sun Jul 3 16:59:40 2016 +0800 Committer: Reynold Xin <r...@databricks.com> Committed: Thu Jul 7 21:02:50 2016 -0700 ---------------------------------------------------------------------- .../catalyst/analysis/FunctionRegistry.scala | 2 + .../expressions/collectionOperations.scala | 48 ++++++++++++++++++++ .../expressions/CollectionFunctionsSuite.scala | 13 ++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 16 +++++++ .../spark/sql/hive/HiveSessionCatalog.scala | 1 - 5 files changed, 79 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/7ef1d1c6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 95be0d6..27c3a09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -170,6 +170,8 @@ object FunctionRegistry { expression[IsNotNull]("isnotnull"), expression[Least]("least"), expression[CreateMap]("map"), + expression[MapKeys]("map_keys"), + expression[MapValues]("map_values"), expression[CreateNamedStruct]("named_struct"), expression[NaNvl]("nanvl"), expression[NullIf]("nullif"), http://git-wip-us.apache.org/repos/asf/spark/blob/7ef1d1c6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c71cb73..2e8ea11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -44,6 +44,54 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType } /** + * Returns an unordered array containing the keys of the map. + */ +@ExpressionDescription( + usage = "_FUNC_(map) - Returns an unordered array containing the keys of the map.", + extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [1,2]") +case class MapKeys(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(MapType) + + override def dataType: DataType = ArrayType(child.dataType.asInstanceOf[MapType].keyType) + + override def nullSafeEval(map: Any): Any = { + map.asInstanceOf[MapData].keyArray() + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).keyArray();") + } + + override def prettyName: String = "map_keys" +} + +/** + * Returns an unordered array containing the values of the map. + */ +@ExpressionDescription( + usage = "_FUNC_(map) - Returns an unordered array containing the values of the map.", + extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [\"a\",\"b\"]") +case class MapValues(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(MapType) + + override def dataType: DataType = ArrayType(child.dataType.asInstanceOf[MapType].valueType) + + override def nullSafeEval(map: Any): Any = { + map.asInstanceOf[MapData].valueArray() + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).valueArray();") + } + + override def prettyName: String = "map_values" +} + +/** * Sorts the input array in ascending / descending order according to the natural ordering of * the array elements and returns it. */ http://git-wip-us.apache.org/repos/asf/spark/blob/7ef1d1c6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 1aae467..a5f784f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -44,6 +44,19 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, ArrayType(StringType)), null) } + test("MapKeys/MapValues") { + val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(MapKeys(m0), Seq("a", "b")) + checkEvaluation(MapValues(m0), Seq("1", "2")) + checkEvaluation(MapKeys(m1), Seq()) + checkEvaluation(MapValues(m1), Seq()) + checkEvaluation(MapKeys(m2), null) + checkEvaluation(MapValues(m2), null) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) http://git-wip-us.apache.org/repos/asf/spark/blob/7ef1d1c6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 73d7765..0f6c49e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -352,6 +352,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("map_keys/map_values function") { + val df = Seq( + (Map[Int, Int](1 -> 100, 2 -> 200), "x"), + (Map[Int, Int](), "y"), + (Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), "z") + ).toDF("a", "b") + checkAnswer( + df.selectExpr("map_keys(a)"), + Seq(Row(Seq(1, 2)), Row(Seq.empty), Row(Seq(1, 2, 3))) + ) + checkAnswer( + df.selectExpr("map_values(a)"), + Seq(Row(Seq(100, 200)), Row(Seq.empty), Row(Seq(100, 200, 300))) + ) + } + test("array contains function") { val df = Seq( (Seq[Int](1, 2), "x"), http://git-wip-us.apache.org/repos/asf/spark/blob/7ef1d1c6/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 1479554..4c986b0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -236,7 +236,6 @@ private[sql] class HiveSessionCatalog( // str_to_map, windowingtablefunction. private val hiveFunctions = Seq( "hash", "java_method", "histogram_numeric", - "map_keys", "map_values", "parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map", "xpath", "xpath_boolean", "xpath_double", "xpath_float", "xpath_int", "xpath_long", "xpath_number", "xpath_short", "xpath_string", --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org