This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 747846bd3ef3 [SPARK-47007][SQL][PYTHON][R][CONNECT] Add the `map_sort` function
747846bd3ef3 is described below

commit 747846bd3ef38eaec204ae32e47bdcb192fd2797
Author: Stevo Mitric <stevo.mit...@databricks.com>
AuthorDate: Wed Mar 20 10:00:11 2024 +0500

    [SPARK-47007][SQL][PYTHON][R][CONNECT] Add the `map_sort` function
    
    ### What changes were proposed in this pull request?
    
    Adding a new function `map_sort` to:
    - Scala API
    - Python API
    - R API
    - Spark Connect Scala Client
    - Spark Connect Python Client
    
    ### Why are the changes needed?
    
    In order to add the ability to do GROUP BY on map types, we first have to be able to sort the maps by their keys.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, a new function `map_sort`.
    
    ### How was this patch tested?
    
    With new unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45069 from stefankandic/SPARK-47007.
    
    Lead-authored-by: Stevo Mitric <stevo.mit...@databricks.com>
    Co-authored-by: Stefan Kandic <stefan.kan...@databricks.com>
    Co-authored-by: Stevo Mitric <stevomitric2...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 R/pkg/NAMESPACE                                    |   1 +
 R/pkg/R/functions.R                                |  17 ++
 R/pkg/R/generics.R                                 |   4 +
 R/pkg/tests/fulltests/test_sparkSQL.R              |   6 +
 .../scala/org/apache/spark/sql/functions.scala     |  17 ++
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   4 +
 .../explain-results/function_map_sort.explain      |   2 +
 .../query-tests/queries/function_map_sort.json     |  29 ++++
 .../queries/function_map_sort.proto.bin            | Bin 0 -> 183 bytes
 .../source/reference/pyspark.sql/functions.rst     |   1 +
 python/pyspark/sql/connect/functions/builtin.py    |   7 +
 python/pyspark/sql/functions/builtin.py            |  48 ++++++
 python/pyspark/sql/tests/test_functions.py         |   7 +
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   1 +
 .../expressions/collectionOperations.scala         | 172 +++++++++++++++++++++
 .../expressions/CollectionExpressionsSuite.scala   |  40 +++++
 .../scala/org/apache/spark/sql/functions.scala     |  17 ++
 .../sql-functions/sql-expression-schema.md         |   1 +
 .../apache/spark/sql/DataFrameFunctionsSuite.scala |  82 +++++++++-
 19 files changed, 455 insertions(+), 1 deletion(-)

diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 3d683ba919a9..a0aa7d0f42ff 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -363,6 +363,7 @@ exportMethods("%<=>%",
               "map_keys",
               "map_values",
               "map_zip_with",
+              "map_sort",
               "max",
               "max_by",
               "md5",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index a7e337d3f9af..bb8085863482 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -4552,6 +4552,23 @@ setMethod("map_zip_with",
            )
           })
 
+#' @details
+#' \code{map_sort}: Sorts the input map in ascending or descending order 
according to
+#' the natural ordering of the map keys.
+#'
+#' @rdname column_collection_functions
+#' @param asc a logical flag indicating the sorting order.
+#'            TRUE, sorting is in ascending order.
+#'            FALSE, sorting is in descending order.
+#' @aliases map_sort map_sort,Column-method
+#' @note map_sort since 4.0.0
+setMethod("map_sort",
+          signature(x = "Column"),
+          function(x, asc = TRUE) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "map_sort", 
x@jc, asc)
+            column(jc)
+          })
+
 #' @details
 #' \code{element_at}: Returns element of array at given index in 
\code{extraction} if
 #' \code{x} is array. Returns value for the given key in \code{extraction} if 
\code{x} is map.
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 10a85c7b891a..58bdd53eae25 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1224,6 +1224,10 @@ setGeneric("map_values", function(x) { 
standardGeneric("map_values") })
 #' @name NULL
 setGeneric("map_zip_with", function(x, y, f) { standardGeneric("map_zip_with") 
})
 
+#' @rdname column_collection_functions
+#' @name NULL
+setGeneric("map_sort", function(x, asc = TRUE) { standardGeneric("map_sort") })
+
 #' @rdname column_aggregate_functions
 #' @name NULL
 setGeneric("max_by", function(x, y) { standardGeneric("max_by") })
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R 
b/R/pkg/tests/fulltests/test_sparkSQL.R
index c44924e55087..540c46b6769f 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1648,6 +1648,12 @@ test_that("column functions", {
   expected_entries <- list(as.environment(list(x = 1, y = 2, a = 3, b = 4)))
   expect_equal(result, expected_entries)
 
+  # Test map_sort
+  df <- createDataFrame(list(list(map1 = as.environment(list(c = 3, a = 1, b = 
2)))))
+  result <- collect(select(df, map_sort(df[[1]])))[[1]]
+  expected_entries <- list(as.environment(list(a = 1, b = 2, c = 3)))
+  expect_equal(result, expected_entries)
+
   # Test map_entries(), map_keys(), map_values() and element_at()
   df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2)))))
   result <- collect(select(df, map_entries(df$map)))[[1]]
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 7610a234ecd9..affebf3ae043 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -7097,6 +7097,23 @@ object functions {
    */
   def sort_array(e: Column, asc: Boolean): Column = Column.fn("sort_array", e, 
lit(asc))
 
+  /**
+   * Sorts the input map in ascending order according to the natural ordering 
of the map keys.
+   *
+   * @group map_funcs
+   * @since 4.0.0
+   */
+  def map_sort(e: Column): Column = map_sort(e, asc = true)
+
+  /**
+   * Sorts the input map in ascending or descending order according to the 
natural ordering of the
+   * map keys.
+   *
+   * @group map_funcs
+   * @since 4.0.0
+   */
+  def map_sort(e: Column, asc: Boolean): Column = Column.fn("map_sort", e, 
lit(asc))
+
   /**
    * Returns the minimum value in the array. NaN is greater than any non-NaN 
elements for
    * double/float type. NULL elements are skipped.
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 46789057ed3c..94fe30059136 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -2533,6 +2533,10 @@ class PlanGenerationTestSuite
     fn.map_from_entries(fn.transform(fn.col("e"), (x, i) => fn.struct(i, x)))
   }
 
+  functionTest("map_sort") {
+    fn.map_sort(fn.col("f"))
+  }
+
   functionTest("arrays_zip") {
     fn.arrays_zip(fn.col("e"), fn.sequence(lit(1), lit(20)))
   }
diff --git 
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain
 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain
new file mode 100644
index 000000000000..069b2ce65d18
--- /dev/null
+++ 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain
@@ -0,0 +1,2 @@
+Project [map_sort(f#0, true) AS map_sort(f, true)#0]
++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json
 
b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json
new file mode 100644
index 000000000000..81a9788d0fba
--- /dev/null
+++ 
b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json
@@ -0,0 +1,29 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "project": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": 
"struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
+      }
+    },
+    "expressions": [{
+      "unresolvedFunction": {
+        "functionName": "map_sort",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "f"
+          }
+        }, {
+          "literal": {
+            "boolean": true
+          }
+        }]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin
 
b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin
new file mode 100644
index 000000000000..57b823a57129
Binary files /dev/null and 
b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin
 differ
diff --git a/python/docs/source/reference/pyspark.sql/functions.rst 
b/python/docs/source/reference/pyspark.sql/functions.rst
index e731c319525e..def17dd675ab 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -396,6 +396,7 @@ Map Functions
     map_from_entries
     map_keys
     map_values
+    map_sort
     str_to_map
 
 
diff --git a/python/pyspark/sql/connect/functions/builtin.py 
b/python/pyspark/sql/connect/functions/builtin.py
index c423c5f188ef..370128ede116 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -2004,6 +2004,13 @@ def map_values(col: "ColumnOrName") -> Column:
 map_values.__doc__ = pysparkfuncs.map_values.__doc__
 
 
+def map_sort(col: "ColumnOrName", asc: bool = True) -> Column:
+    return _invoke_function("map_sort", _to_col(col), lit(asc))
+
+
+map_sort.__doc__ = pysparkfuncs.map_sort.__doc__
+
+
 def map_zip_with(
     col1: "ColumnOrName",
     col2: "ColumnOrName",
diff --git a/python/pyspark/sql/functions/builtin.py 
b/python/pyspark/sql/functions/builtin.py
index f9d96778b886..4f53f4a664f1 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -16878,6 +16878,54 @@ def map_concat(
     return _invoke_function_over_seq_of_columns("map_concat", cols)  # type: 
ignore[arg-type]
 
 
+@_try_remote_functions
+def map_sort(col: "ColumnOrName", asc: bool = True) -> Column:
+    """
+    Map function: Sorts the input map in ascending or descending order 
according
+    to the natural ordering of the map keys.
+
+    .. versionadded:: 4.0.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        Name of the column or expression.
+    asc : bool, optional
+        Whether to sort in ascending or descending order. If `asc` is True 
(default),
+        then the sorting is in ascending order. If False, then in descending 
order.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        Sorted map.
+
+    Examples
+    --------
+    Example 1: Sorting a map in ascending order
+
+    >>> import pyspark.sql.functions as sf
+    >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data")
+    >>> df.select(sf.map_sort(df.data)).show(truncate=False)
+    +------------------------+
+    |map_sort(data, true)    |
+    +------------------------+
+    |{1 -> a, 2 -> b, 3 -> c}|
+    +------------------------+
+
+    Example 2: Sorting a map in descending order
+
+    >>> import pyspark.sql.functions as sf
+    >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data")
+    >>> df.select(sf.map_sort(df.data, False)).show(truncate=False)
+    +------------------------+
+    |map_sort(data, false)   |
+    +------------------------+
+    |{3 -> c, 2 -> b, 1 -> a}|
+    +------------------------+
+    """
+    return _invoke_function("map_sort", _to_java_column(col), asc)
+
+
 @_try_remote_functions
 def sequence(
     start: "ColumnOrName", stop: "ColumnOrName", step: 
Optional["ColumnOrName"] = None
diff --git a/python/pyspark/sql/tests/test_functions.py 
b/python/pyspark/sql/tests/test_functions.py
index e42fd9fa7bf6..def2cee41a4c 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1445,6 +1445,13 @@ class FunctionsTestsMixin:
             {1: "a", 2: "b", 3: "c"},
         )
 
+    def test_map_sort(self):
+        df = self.spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as map1")
+        self.assertEqual(
+            df.select(F.map_sort("map1").alias("map2")).first()[0],
+            {1: "a", 2: "b", 3: "c"},
+        )
+
     def test_version(self):
         
self.assertIsInstance(self.spark.range(1).select(F.version()).first()[0], str)
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index b165d20d0b4f..f64f88cfd9b6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -696,6 +696,7 @@ object FunctionRegistry {
     expression[MapEntries]("map_entries"),
     expression[MapFromEntries]("map_from_entries"),
     expression[MapConcat]("map_concat"),
+    expression[MapSort]("map_sort"),
     expression[Size]("size"),
     expression[Slice]("slice"),
     expression[Size]("cardinality", true),
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index a090bdf2bebf..3ed711d47762 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -888,6 +888,178 @@ case class MapFromEntries(child: Expression)
     copy(child = newChild)
 }
 
+@ExpressionDescription(
+  usage = """
+    _FUNC_(map[, ascendingOrder]) - Sorts the input map in ascending or descending order
+      according to the natural ordering of the map keys. The algorithm used for sorting is
+      an adaptive, stable and iterative algorithm. If the input map is empty, function
+      returns an empty map.
+  """,
+  arguments =
+    """
+    Arguments:
+      * map - The map that will be sorted.
+      * ascendingOrder - A boolean value describing the order in which the map will be sorted.
+          This can be either ascending (true, the default) or descending (false).
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(map(3, 'c', 1, 'a', 2, 'b'), true);
+       {1:"a",2:"b",3:"c"}
+  """,
+  group = "map_funcs",
+  since = "4.0.0")
+case class MapSort(base: Expression, ascendingOrder: Expression)
+  extends BinaryExpression with NullIntolerant with QueryErrorsBase {
+
+  def this(e: Expression) = this(e, Literal(true))
+  // Lazy, so a non-map `base` is rejected by checkInputDataTypes instead of throwing here.
+  lazy val keyType: DataType = base.dataType.asInstanceOf[MapType].keyType
+  lazy val valueType: DataType = base.dataType.asInstanceOf[MapType].valueType
+
+  override def left: Expression = base
+  override def right: Expression = ascendingOrder
+  override def dataType: DataType = base.dataType
+
+  override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
+    case m: MapType if RowOrdering.isOrderable(m.keyType) =>
+      ascendingOrder match {
+        case Literal(_: Boolean, BooleanType) =>
+          TypeCheckResult.TypeCheckSuccess
+        case _ =>
+          DataTypeMismatch(
+            errorSubClass = "UNEXPECTED_INPUT_TYPE",
+            messageParameters = Map(
+              "paramIndex" -> ordinalNumber(1),
+              "requiredType" -> toSQLType(BooleanType),
+              "inputSql" -> toSQLExpr(ascendingOrder),
+              "inputType" -> toSQLType(ascendingOrder.dataType))
+          )
+      }
+    case _: MapType =>
+      DataTypeMismatch(
+        errorSubClass = "INVALID_ORDERING_TYPE",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "dataType" -> toSQLType(base.dataType)
+        )
+      )
+    case _ =>
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> ordinalNumber(0),
+          "requiredType" -> toSQLType(MapType),
+          "inputSql" -> toSQLExpr(base),
+          "inputType" -> toSQLType(base.dataType))
+      )
+  }
+
+  override def nullSafeEval(mapInput: Any, ascending: Any): Any = {
+    // put keys and their respective values inside a tuple and sort them
+    // according to the key ordering. Extract the new sorted k/v pairs to form a sorted map
+
+    val mapData = mapInput.asInstanceOf[MapData]
+    val numElements = mapData.numElements()
+    val keys = mapData.keyArray()
+    val values = mapData.valueArray()
+
+    val ordering = if (ascending.asInstanceOf[Boolean]) {
+      PhysicalDataType.ordering(keyType)
+    } else {
+      PhysicalDataType.ordering(keyType).reverse
+    }
+
+    val sortedMap = Array
+      .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any],
+        values.get(i, valueType).asInstanceOf[Any]))
+      .sortBy(_._1)(ordering)
+
+    new ArrayBasedMapData(new GenericArrayData(sortedMap.map(_._1)),
+      new GenericArrayData(sortedMap.map(_._2)))
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order))
+  }
+
+  private def sortCodegen(ctx: CodegenContext, ev: ExprCode,
+      base: String, order: String): String = {
+
+    val arrayBasedMapData = classOf[ArrayBasedMapData].getName
+    val genericArrayData = classOf[GenericArrayData].getName
+
+    val numElements = ctx.freshName("numElements")
+    val keys = ctx.freshName("keys")
+    val values = ctx.freshName("values")
+    val sortArray = ctx.freshName("sortArray")
+    val i = ctx.freshName("i")
+    val o1 = ctx.freshName("o1")
+    val o1entry = ctx.freshName("o1entry")
+    val o2 = ctx.freshName("o2")
+    val o2entry = ctx.freshName("o2entry")
+    val c = ctx.freshName("c")
+    val newKeys = ctx.freshName("newKeys")
+    val newValues = ctx.freshName("newValues")
+
+    val boxedKeyType = CodeGenerator.boxedType(keyType)
+    val boxedValueType = CodeGenerator.boxedType(valueType)
+    val javaKeyType = CodeGenerator.javaType(keyType)
+
+    val simpleEntryType = s"java.util.AbstractMap.SimpleEntry<$boxedKeyType, $boxedValueType>"
+
+    val comp = if (CodeGenerator.isPrimitiveType(keyType)) {
+      val v1 = ctx.freshName("v1")
+      val v2 = ctx.freshName("v2")
+      s"""
+         |$javaKeyType $v1 = (($boxedKeyType) $o1).${javaKeyType}Value();
+         |$javaKeyType $v2 = (($boxedKeyType) $o2).${javaKeyType}Value();
+         |int $c = ${ctx.genComp(keyType, v1, v2)};
+       """.stripMargin
+    } else {
+      s"int $c = ${ctx.genComp(keyType, s"(($javaKeyType) $o1)", s"(($javaKeyType) $o2)")};"
+    }
+
+    s"""
+       |final int $numElements = $base.numElements();
+       |ArrayData $keys = $base.keyArray();
+       |ArrayData $values = $base.valueArray();
+       |
+       |Object[] $sortArray = new Object[$numElements];
+       |
+       |for (int $i = 0; $i < $numElements; $i++) {
+       |  $sortArray[$i] = new $simpleEntryType(
+       |    ${CodeGenerator.getValue(keys, keyType, i)},
+       |    ${CodeGenerator.getValue(values, valueType, i)});
+       |}
+       |
+       |java.util.Arrays.sort($sortArray, new java.util.Comparator<Object>() {
+       |  @Override public int compare(Object $o1entry, Object $o2entry) {
+       |    Object $o1 = (($simpleEntryType) $o1entry).getKey();
+       |    Object $o2 = (($simpleEntryType) $o2entry).getKey();
+       |    $comp;
+       |    return $order ? $c : -$c;
+       |  }
+       |});
+       |
+       |Object[] $newKeys = new Object[$numElements];
+       |Object[] $newValues = new Object[$numElements];
+       |
+       |for (int $i = 0; $i < $numElements; $i++) {
+       |  $newKeys[$i] = (($simpleEntryType) $sortArray[$i]).getKey();
+       |  $newValues[$i] = (($simpleEntryType) $sortArray[$i]).getValue();
+       |}
+       |
+       |${ev.value} = new $arrayBasedMapData(
+       |  new $genericArrayData($newKeys), new $genericArrayData($newValues));
+       |""".stripMargin
+  }
+
+  override def prettyName: String = "map_sort"
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression)
+    : MapSort = copy(base = newLeft, ascendingOrder = newRight)
+}
 
 /**
  * Common base class for [[SortArray]] and [[ArraySort]].
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 133e27c5b0a6..3063b83d4dca 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -421,6 +421,46 @@ class CollectionExpressionsSuite
     )
   }
 
+  test("Sort Map") {
+    val intKey = Literal.create(Map(2 -> 2, 1 -> 1, 3 -> 3), 
MapType(IntegerType, IntegerType))
+    val boolKey = Literal.create(Map(true -> 2, false -> 1), 
MapType(BooleanType, IntegerType))
+    val stringKey = Literal.create(Map("2" -> 2, "1" -> 1, "3" -> 3),
+      MapType(StringType, IntegerType))
+    val arrayKey = Literal.create(Map(Seq(2) -> 2, Seq(1) -> 1, Seq(3) -> 3),
+      MapType(ArrayType(IntegerType), IntegerType))
+    val nestedArrayKey = Literal.create(Map(Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 
1, Seq(Seq(3)) -> 3),
+      MapType(ArrayType(ArrayType(IntegerType)), IntegerType))
+    val structKey = Literal.create(
+      Map(create_row(2) -> 2, create_row(1) -> 1, create_row(3) -> 3),
+      MapType(StructType(Seq(StructField("a", IntegerType))), IntegerType))
+
+    checkEvaluation(new MapSort(intKey), Map(1 -> 1, 2 -> 2, 3 -> 3))
+    checkEvaluation(MapSort(intKey, Literal.create(false, BooleanType)),
+      Map(3 -> 3, 2 -> 2, 1 -> 1))
+
+    checkEvaluation(new MapSort(boolKey), Map(false -> 1, true -> 2))
+    checkEvaluation(MapSort(boolKey, Literal.create(false, BooleanType)),
+      Map(true -> 2, false -> 1))
+
+    checkEvaluation(new MapSort(stringKey), Map("1" -> 1, "2" -> 2, "3" -> 3))
+    checkEvaluation(MapSort(stringKey, Literal.create(false, BooleanType)),
+      Map("3" -> 3, "2" -> 2, "1" -> 1))
+
+    checkEvaluation(new MapSort(arrayKey), Map(Seq(1) -> 1, Seq(2) -> 2, 
Seq(3) -> 3))
+    checkEvaluation(MapSort(arrayKey, Literal.create(false, BooleanType)),
+      Map(Seq(3) -> 3, Seq(2) -> 2, Seq(1) -> 1))
+
+    checkEvaluation(new MapSort(nestedArrayKey),
+      Map(Seq(Seq(1)) -> 1, Seq(Seq(2)) -> 2, Seq(Seq(3)) -> 3))
+    checkEvaluation(MapSort(nestedArrayKey, Literal.create(false, 
BooleanType)),
+      Map(Seq(Seq(3)) -> 3, Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 1))
+
+    checkEvaluation(new MapSort(structKey),
+      Map(create_row(1) -> 1, create_row(2) -> 2, create_row(3) -> 3))
+    checkEvaluation(MapSort(structKey, Literal.create(false, BooleanType)),
+      Map(create_row(3) -> 3, create_row(2) -> 2, create_row(1) -> 1))
+  }
+
   test("Sort Array") {
     val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
     val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d2dc3a326389..d589092070ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -7000,6 +7000,23 @@ object functions {
   @scala.annotation.varargs
   def map_concat(cols: Column*): Column = Column.fn("map_concat", cols: _*)
 
+  /**
+   * Sorts the input map in ascending order based on the natural order of map 
keys.
+   *
+   * @group map_funcs
+   * @since 4.0.0
+   */
+  def map_sort(e: Column): Column = map_sort(e, asc = true)
+
+  /**
+   * Sorts the input map in ascending or descending order according to the 
natural ordering
+   * of the map keys.
+   *
+   * @group map_funcs
+   * @since 4.0.0
+   */
+  def map_sort(e: Column, asc: Boolean): Column = Column.fn("map_sort", e, 
lit(asc))
+
   // scalastyle:off line.size.limit
   /**
    * Parses a column containing a CSV string into a `StructType` with the 
specified schema.
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md 
b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index bd1b6f0cb753..3e2b7867ef3c 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -215,6 +215,7 @@
 | org.apache.spark.sql.catalyst.expressions.MapFromArrays | map_from_arrays | 
SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')) | 
struct<map_from_arrays(array(1.0, 3.0), array(2, 4)):map<decimal(2,1),string>> |
 | org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries 
| SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | 
struct<map_from_entries(array(struct(1, a), struct(2, b))):map<int,string>> |
 | org.apache.spark.sql.catalyst.expressions.MapKeys | map_keys | SELECT 
map_keys(map(1, 'a', 2, 'b')) | struct<map_keys(map(1, a, 2, b)):array<int>> |
+| org.apache.spark.sql.catalyst.expressions.MapSort | map_sort | SELECT 
map_sort(map(3, 'c', 1, 'a', 2, 'b'), true) | struct<map_sort(map(3, c, 1, a, 
2, b), true):map<int,string>> |
 | org.apache.spark.sql.catalyst.expressions.MapValues | map_values | SELECT 
map_values(map(1, 'a', 2, 'b')) | struct<map_values(map(1, a, 2, 
b)):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.MapZipWith | map_zip_with | SELECT 
map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> 
concat(v1, v2)) | struct<map_zip_with(map(1, a, 2, b), map(1, x, 2, y), 
lambdafunction(concat(namedlambdavariable(), namedlambdavariable()), 
namedlambdavariable(), namedlambdavariable(), 
namedlambdavariable())):map<int,string>> |
 | org.apache.spark.sql.catalyst.expressions.MaskExpressionBuilder | mask | 
SELECT mask('abcd-EFGH-8765-4321') | struct<mask(abcd-EFGH-8765-4321, X, x, n, 
NULL):string> |
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index e42f397cbfc2..e5953e59a51b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -25,7 +25,7 @@ import java.sql.{Date, Timestamp}
 import scala.util.Random
 
 import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkRuntimeException}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, 
UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions.{Alias, ArraysZip, 
AttributeReference, Expression, NamedExpression, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.Cast._
@@ -780,6 +780,86 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSparkSession {
     )
   }
 
+  test("map_sort function") {
+    val df1 = Seq(
+      Map[Int, Int](2 -> 2, 1 -> 1, 3 -> 3)
+    ).toDF("a")
+
+    checkAnswer(
+      df1.selectExpr("map_sort(a)"),
+      Seq(
+        Row(Map(1 -> 1, 2 -> 2, 3 -> 3))
+      )
+    )
+    checkAnswer(
+      df1.selectExpr("map_sort(a, true)"),
+      Seq(
+        Row(Map(1 -> 1, 2 -> 2, 3 -> 3))
+      )
+    )
+    checkAnswer(
+      df1.select(map_sort($"a", asc = false)),
+      Seq(
+        Row(Map(3 -> 3, 2 -> 2, 1 -> 1))
+      )
+    )
+
+    val df2 = Seq(Map.empty[Int, Int]).toDF("a")
+
+    checkAnswer(
+      df2.selectExpr("map_sort(a, true)"),
+      Seq(Row(Map()))
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.orderBy("a")
+      },
+      errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE",
+      parameters = Map(
+        "functionName" -> "`sortorder`",
+        "dataType" -> "\"MAP<INT, INT>\"",
+        "sqlExpr" -> "\"a ASC NULLS FIRST\"")
+    )
+
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        sql("SELECT map_sort(map(null, 1))").collect()
+      },
+      errorClass = "NULL_MAP_KEY"
+    )
+
+    checkError(
+      exception = intercept[ExtendedAnalysisException] {
+        sql("SELECT map_sort(map(1,1,2,2), \"asc\")").collect()
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), asc)\"",
+        "paramIndex" -> "second",
+        "inputSql" -> "\"asc\"",
+        "inputType" -> "\"STRING\"",
+        "requiredType" -> "\"BOOLEAN\""
+      ),
+      queryContext = Array(ExpectedContext("", "", 7, 35, "map_sort(map(1,1,2,2), \"asc\")"))
+    )
+
+    checkError(
+      exception = intercept[ExtendedAnalysisException] {
+        sql("SELECT map_sort(map(1,1,2,2), 1)").collect()
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), 1)\"",
+        "paramIndex" -> "second",
+        "inputSql" -> "\"1\"",
+        "inputType" -> "\"INT\"",
+        "requiredType" -> "\"BOOLEAN\""
+      ),
+      queryContext = Array(ExpectedContext("", "", 7, 31, "map_sort(map(1,1,2,2), 1)"))
+    )
+  }
+
   test("sort_array/array_sort functions") {
     val df = Seq(
       (Array[Int](2, 1, 3), Array("b", "c", "a")),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to