This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 295e98d29b3 [SPARK-40214][PYTHON][SQL] add 'get' to functions 295e98d29b3 is described below commit 295e98d29b34e2b472c375608b8782c3b9189444 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Thu Aug 25 14:44:18 2022 +0800 [SPARK-40214][PYTHON][SQL] add 'get' to functions ### What changes were proposed in this pull request? expose `get` to dataframe functions ### Why are the changes needed? for function parity ### Does this PR introduce _any_ user-facing change? yes, new API ### How was this patch tested? added UT Closes #37652 from zhengruifeng/py_get. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../source/reference/pyspark.sql/functions.rst | 1 + python/pyspark/sql/functions.py | 70 ++++++++++++++++++++++ .../scala/org/apache/spark/sql/functions.scala | 11 ++++ .../apache/spark/sql/DataFrameFunctionsSuite.scala | 38 ++++++++++++ 4 files changed, 120 insertions(+) diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index a799bb8ad0a..027babbf57d 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -176,6 +176,7 @@ Collection Functions explode_outer posexplode posexplode_outer + get get_json_object json_tuple from_json diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d59532f52cb..fd7a7247fc8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -4832,6 +4832,10 @@ def element_at(col: "ColumnOrName", extraction: Any) -> Column: ----- The position is not zero based, but 1 based index. 
+ See Also + -------- + :meth:`get` + Examples -------- >>> df = spark.createDataFrame([(["a", "b", "c"],)], ['data']) @@ -4845,6 +4849,72 @@ def element_at(col: "ColumnOrName", extraction: Any) -> Column: return _invoke_function_over_columns("element_at", col, lit(extraction)) +def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: + """ + Collection function: Returns element of array at given (0-based) index. + If the index points outside of the array boundaries, then this function + returns NULL. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing array + index : :class:`~pyspark.sql.Column` or str or int + index of the element to return (0-based) + + Notes + ----- + The index is 0-based, not 1-based. + + See Also + -------- + :meth:`element_at` + + Examples + -------- + >>> df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index']) + >>> df.select(get(df.data, 1)).show() + +------------+ + |get(data, 1)| + +------------+ + | b| + +------------+ + + >>> df.select(get(df.data, -1)).show() + +-------------+ + |get(data, -1)| + +-------------+ + | null| + +-------------+ + + >>> df.select(get(df.data, 3)).show() + +------------+ + |get(data, 3)| + +------------+ + | null| + +------------+ + + >>> df.select(get(df.data, "index")).show() + +----------------+ + |get(data, index)| + +----------------+ + | b| + +----------------+ + + >>> df.select(get(df.data, col("index") - 1)).show() + +----------------------+ + |get(data, (index - 1))| + +----------------------+ + | a| + +----------------------+ + """ + index = lit(index) if isinstance(index, int) else index + + return _invoke_function_over_columns("get", col, index) + + def array_remove(col: "ColumnOrName", element: Any) -> Column: + """ + Collection function: Remove all elements that equal to element from the given array. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bd7473706ca..69da277d5e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3958,6 +3958,17 @@ object functions { ElementAt(column.expr, lit(value).expr) } + /** + * Returns element of array at given (0-based) index. If the index points + * outside of the array boundaries, then this function returns NULL. + * + * @group collection_funcs + * @since 3.4.0 + */ + def get(column: Column, index: Column): Column = withExpr { + new Get(column.expr, index.expr) + } + /** * Sorts the input array in ascending order. The elements of the input array must be orderable. * NaN is greater than any non-NaN elements for double/float type. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b80925f8638..ee41b1efba2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1628,6 +1628,44 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { assert(e3.message.contains(errorMsg3)) } + test("SPARK-40214: get function") { + val df = Seq( + (Seq[String]("1", "2", "3"), 2), + (Seq[String](null, ""), 1), + (Seq[String](), 2), + (null, 3) + ).toDF("a", "b") + + checkAnswer( + df.select(get(df("a"), lit(-1))), + Seq(Row(null), Row(null), Row(null), Row(null)) + ) + checkAnswer( + df.select(get(df("a"), lit(0))), + Seq(Row("1"), Row(null), Row(null), Row(null)) + ) + checkAnswer( + df.select(get(df("a"), lit(1))), + Seq(Row("2"), Row(""), Row(null), Row(null)) + ) + checkAnswer( + df.select(get(df("a"), lit(2))), + Seq(Row("3"), Row(null), Row(null), Row(null)) + ) + checkAnswer( + df.select(get(df("a"), 
lit(3))), + Seq(Row(null), Row(null), Row(null), Row(null)) + ) + checkAnswer( + df.select(get(df("a"), df("b"))), + Seq(Row("3"), Row(""), Row(null), Row(null)) + ) + checkAnswer( + df.select(get(df("a"), df("b") - 1)), + Seq(Row("2"), Row(null), Row(null), Row(null)) + ) + } + test("array_union functions") { val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b") val ans1 = Row(Seq(1, 2, 3, 4)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org