This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 1e26abf824b [SPARK-39925][SQL] Add array_sort(column, comparator) overload to DataFrame operations 1e26abf824b is described below commit 1e26abf824b7a1bb4e13cb5233d8f1d81388a095 Author: Brandon Dahler <b...@amazon.com> AuthorDate: Sun Aug 21 18:21:13 2022 +0900 [SPARK-39925][SQL] Add array_sort(column, comparator) overload to DataFrame operations ### What changes were proposed in this pull request? Adding a new `array_sort` overload to `org.apache.spark.sql.functions` that matches the new overload defined in [SPARK-29020](https://issues.apache.org/jira/browse/SPARK-29020) and added via #25728. ### Why are the changes needed? Adds access to the new overload for users of the DataFrame API so that they don't need to use the `expr` escape hatch. ### Does this PR introduce _any_ user-facing change? Yes, this change now allows users to optionally provide a comparator function to `array_sort`, which opens up the ability to sort descending as well as to sort items that aren't naturally orderable. #### Example: Old: ``` df.selectExpr("array_sort(a, (x, y) -> cardinality(x) - cardinality(y))"); ``` Added: ``` df.select(array_sort(col("a"), (x, y) => size(x) - size(y))); ``` ### How was this patch tested? Unit tests updated to validate that the overload matches the expression's behavior. Closes #37361 from brandondahler/features/ArraySortOverload. 
Authored-by: Brandon Dahler <b...@amazon.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../scala/org/apache/spark/sql/functions.scala | 13 +++++++ .../apache/spark/sql/DataFrameFunctionsSuite.scala | 42 ++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 533c5614885..bd7473706ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3968,6 +3968,19 @@ object functions { */ def array_sort(e: Column): Column = withExpr { new ArraySort(e.expr) } + /** + * Sorts the input array based on the given comparator function. The comparator will take two + * arguments representing two elements of the array. It returns a negative integer, 0, or a + * positive integer as the first element is less than, equal to, or greater than the second + * element. If the comparator function returns null, the function will fail and raise an error. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_sort(e: Column, comparator: (Column, Column) => Column): Column = withExpr { + new ArraySort(e.expr, createLambda(comparator)) + } + /** * Remove all elements that equal to element from the given array. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9924fbfbf62..b80925f8638 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -434,6 +434,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { }) val df1 = Seq(Array[Int](3, 2, 5, 1, 2)).toDF("a") + checkAnswer( + df1.select(array_sort(col("a"), (x, y) => call_udf("fAsc", x, y))), + Seq( + Row(Seq(1, 2, 2, 3, 5))) + ) + + checkAnswer( + df1.select(array_sort(col("a"), (x, y) => call_udf("fDesc", x, y))), + Seq( + Row(Seq(5, 3, 2, 2, 1))) + ) + checkAnswer( df1.selectExpr("array_sort(a, (x, y) -> fAsc(x, y))"), Seq( @@ -447,6 +459,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) val df2 = Seq(Array[String]("bc", "ab", "dc")).toDF("a") + checkAnswer( + df2.select(array_sort(col("a"), (x, y) => call_udf("fString", x, y))), + Seq( + Row(Seq("dc", "bc", "ab"))) + ) + checkAnswer( df2.selectExpr("array_sort(a, (x, y) -> fString(x, y))"), Seq( @@ -454,6 +472,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) val df3 = Seq(Array[String]("a", "abcd", "abc")).toDF("a") + checkAnswer( + df3.select(array_sort(col("a"), (x, y) => call_udf("fStringLength", x, y))), + Seq( + Row(Seq("a", "abc", "abcd"))) + ) + checkAnswer( df3.selectExpr("array_sort(a, (x, y) -> fStringLength(x, y))"), Seq( @@ -462,6 +486,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { val df4 = Seq((Array[Array[Int]](Array(2, 3, 1), Array(4, 2, 1, 4), Array(1, 2)), "x")).toDF("a", "b") + checkAnswer( + df4.select(array_sort(col("a"), (x, y) => call_udf("fAsc", size(x), size(y)))), + Seq( + Row(Seq[Seq[Int]](Seq(1, 2), Seq(2, 3, 1), Seq(4, 2, 1, 4)))) + ) + checkAnswer( df4.selectExpr("array_sort(a, (x, y) 
-> fAsc(cardinality(x), cardinality(y)))"), Seq( @@ -469,6 +499,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) val df5 = Seq(Array[String]("bc", null, "ab", "dc")).toDF("a") + checkAnswer( + df5.select(array_sort(col("a"), (x, y) => call_udf("fString", x, y))), + Seq( + Row(Seq("dc", "bc", "ab", null))) + ) + checkAnswer( df5.selectExpr("array_sort(a, (x, y) -> fString(x, y))"), Seq( @@ -484,6 +520,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { test("SPARK-38130: array_sort with lambda of non-orderable items") { val df6 = Seq((Array[Map[String, Int]](Map("a" -> 1), Map("b" -> 2, "c" -> 3), Map()), "x")).toDF("a", "b") + checkAnswer( + df6.select(array_sort(col("a"), (x, y) => size(x) - size(y))), + Seq( + Row(Seq[Map[String, Int]](Map(), Map("a" -> 1), Map("b" -> 2, "c" -> 3)))) + ) + checkAnswer( df6.selectExpr("array_sort(a, (x, y) -> cardinality(x) - cardinality(y))"), Seq( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org