This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new d53ba6276f4 [SPARK-41436][CONNECT][PYTHON] Implement `collection` functions: A~C d53ba6276f4 is described below commit d53ba6276f407b6e090c9610b85a56d047e56f73 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Thu Dec 8 16:59:13 2022 +0800 [SPARK-41436][CONNECT][PYTHON] Implement `collection` functions: A~C ### What changes were proposed in this pull request? Implement [`collection` functions](https://github.com/apache/spark/blob/master/python/docs/source/reference/pyspark.sql/functions.rst#collection-functions) alphabetically, this PR contains `A` ~ `C` except: - `aggregate`, `array_sort` - need the support of LambdaFunction Expression - the `int count` in `array_repeat` - need to support datatype in LiteralExpression in the Python Client ### Why are the changes needed? For API coverage ### Does this PR introduce _any_ user-facing change? new APIs ### How was this patch tested? added UT Closes #38961 from zhengruifeng/connect_function_collect_1. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/sql/connect/functions.py | 608 ++++++++++++++++++++- .../sql/tests/connect/test_connect_function.py | 156 ++++++ 2 files changed, 763 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 8b36647ae5b..7eb17bd89ac 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -90,7 +90,10 @@ column = col def lit(col: Any) -> Column: - return Column(LiteralExpression(col)) + if isinstance(col, Column): + return col + else: + return Column(LiteralExpression(col)) # def bitwiseNOT(col: "ColumnOrName") -> Column: @@ -3208,6 +3211,609 @@ def variance(col: "ColumnOrName") -> Column: return var_samp(col) +# Collection Functions + + +# TODO(SPARK-41434): need to support LambdaFunction Expression first +# def aggregate( +# col: "ColumnOrName", +# initialValue: "ColumnOrName", +# merge: Callable[[Column, Column], Column], +# finish: Optional[Callable[[Column], Column]] = None, +# ) -> Column: +# """ +# Applies a binary operator to an initial state and all elements in the array, +# and reduces this to a single state. The final state is converted into the final result +# by applying a finish function. +# +# Both functions can use methods of :class:`~pyspark.sql.Column`, functions defined in +# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``. +# Python ``UserDefinedFunctions`` are not supported +# (`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__). +# +# .. versionadded:: 3.1.0 +# +# Parameters +# ---------- +# col : :class:`~pyspark.sql.Column` or str +# name of column or expression +# initialValue : :class:`~pyspark.sql.Column` or str +# initial value. Name of column or expression +# merge : function +# a binary function ``(acc: Column, x: Column) -> Column...`` returning expression +# of the same type as ``zero`` +# finish : function +# an optional unary function ``(x: Column) -> Column: ...`` +# used to convert accumulated value. +# +# Returns +# ------- +# :class:`~pyspark.sql.Column` +# final value after aggregate function is applied. 
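The lit() change at the top of this patch makes the function idempotent: callers can pass either a raw Python value or an already-wrapped Column, and helpers such as array_contains below rely on that. A minimal sketch of the dispatch pattern, using plain stand-in classes rather than the real Spark Connect Column and LiteralExpression types:

from typing import Any

class LiteralExpression:
    """Stand-in for the Connect client's literal expression node."""
    def __init__(self, value: Any) -> None:
        self.value = value

class Column:
    """Stand-in for pyspark.sql.connect.column.Column."""
    def __init__(self, expr: Any) -> None:
        self.expr = expr

def lit(col: Any) -> Column:
    # Pass an existing Column through untouched; wrap everything else.
    if isinstance(col, Column):
        return col
    return Column(LiteralExpression(col))

c = lit(42)
assert lit(c) is c  # idempotent: no double wrapping
assert isinstance(lit("a").expr, LiteralExpression)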
+# +# Examples +# -------- +# >>> df = spark.createDataFrame([(1, [20.0, 4.0, 2.0, 6.0, 10.0])], ("id", "values")) +# >>> df.select(aggregate("values", lit(0.0), lambda acc, x: acc + x).alias("sum")).show() +# +----+ +# | sum| +# +----+ +# |42.0| +# +----+ +# +# >>> def merge(acc, x): +# ... count = acc.count + 1 +# ... sum = acc.sum + x +# ... return struct(count.alias("count"), sum.alias("sum")) +# >>> df.select( +# ... aggregate( +# ... "values", +# ... struct(lit(0).alias("count"), lit(0.0).alias("sum")), +# ... merge, +# ... lambda acc: acc.sum / acc.count, +# ... ).alias("mean") +# ... ).show() +# +----+ +# |mean| +# +----+ +# | 8.4| +# +----+ +# """ +# if finish is not None: +# return _invoke_higher_order_function("ArrayAggregate", [col, initialValue], +# [merge, finish]) +# +# else: +# return _invoke_higher_order_function("ArrayAggregate", [col, initialValue], +# [merge]) + + +def array(*cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column: + """Creates a new array column. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + cols : :class:`~pyspark.sql.Column` or str + column names or :class:`~pyspark.sql.Column`\\s that have + the same data type. + + Returns + ------- + :class:`~pyspark.sql.Column` + a column of array type. + + Examples + -------- + >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age")) + >>> df.select(array('age', 'age').alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + >>> df.select(array([df.age, df.age]).alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + >>> df.select(array('age', 'age').alias("col")).printSchema() + root + |-- col: array (nullable = false) + | |-- element: long (containsNull = true) + """ + if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)): + cols = cols[0] # type: ignore[assignment] + return _invoke_function_over_columns("array", *cols) # type: ignore[arg-type] + + +def array_contains(col: "ColumnOrName", value: Any) -> Column: + """ + Collection function: returns null if the array is null, true if the array contains the + given value, and false otherwise. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing array + value : + value or column to check for in array + + Returns + ------- + :class:`~pyspark.sql.Column` + a column of Boolean type. + + Examples + -------- + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(array_contains(df.data, "a")).collect() + [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] + >>> df.select(array_contains(df.data, lit("a"))).collect() + [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] + """ + return _invoke_function("array_contains", _to_col(col), lit(value)) + + +def array_distinct(col: "ColumnOrName") -> Column: + """ + Collection function: removes duplicate values from the array. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column or expression + + Returns + ------- + :class:`~pyspark.sql.Column` + an array of unique values. 
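The array() and array_contains() definitions above can be exercised end to end with the classic PySpark API, which these Connect functions are written to mirror (the parity tests at the bottom of this patch check exactly that). A runnable sketch, assuming a local PySpark installation:

from pyspark.sql import SparkSession
from pyspark.sql.functions import array, array_contains, lit

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age"))
# array() accepts either varargs or a single list/tuple of columns.
df.select(array("age", "age").alias("arr")).show()
df.select(array([df.age, df.age]).alias("arr")).show()

arr_df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ["data"])
# The needle may be a raw value or a Column; lit() normalizes both.
arr_df.select(array_contains("data", lit("a")).alias("has_a")).show()

spark.stop()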
+ + Examples + -------- + >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) + >>> df.select(array_distinct(df.data)).collect() + [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] + """ + return _invoke_function_over_columns("array_distinct", col) + + +def array_except(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: + """ + Collection function: returns an array of the elements in col1 but not in col2, + without duplicates. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col1 : :class:`~pyspark.sql.Column` or str + name of column containing array + col2 : :class:`~pyspark.sql.Column` or str + name of column containing array + + Returns + ------- + :class:`~pyspark.sql.Column` + an array of values from first array that are not in the second. + + Examples + -------- + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_except(df.c1, df.c2)).collect() + [Row(array_except(c1, c2)=['b'])] + """ + return _invoke_function_over_columns("array_except", col1, col2) + + +def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: + """ + Collection function: returns an array of the elements in the intersection of col1 and col2, + without duplicates. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col1 : :class:`~pyspark.sql.Column` or str + name of column containing array + col2 : :class:`~pyspark.sql.Column` or str + name of column containing array + + Returns + ------- + :class:`~pyspark.sql.Column` + an array of values in the intersection of two arrays. + + Examples + -------- + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_intersect(df.c1, df.c2)).collect() + [Row(array_intersect(c1, c2)=['a', 'c'])] + """ + return _invoke_function_over_columns("array_intersect", col1, col2) + + +def array_join( + col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None +) -> Column: + """ + Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + `null_replacement` if set, otherwise they are ignored. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + delimiter : str + delimiter used to concatenate elements + null_replacement : str, optional + if set then null values will be replaced by this value + + Returns + ------- + :class:`~pyspark.sql.Column` + a column of string type. Concatenated values. + + Examples + -------- + >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + >>> df.select(array_join(df.data, ",").alias("joined")).collect() + [Row(joined='a,b,c'), Row(joined='a')] + >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() + [Row(joined='a,b,c'), Row(joined='a,NULL')] + """ + if null_replacement is None: + return _invoke_function("array_join", _to_col(col), lit(delimiter)) + else: + return _invoke_function("array_join", _to_col(col), lit(delimiter), lit(null_replacement)) + + +def array_max(col: "ColumnOrName") -> Column: + """ + Collection function: returns the maximum value of the array. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column or expression + + Returns + ------- + :class:`~pyspark.sql.Column` + maximum value of an array. 
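The set-style operations and array_join() above differ in how they treat duplicates and nulls: array_except and array_intersect deduplicate their results, while array_join silently drops nulls unless a replacement string is given. A classic-API sketch under the same local-session assumption:

from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import array_except, array_intersect, array_join

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
# Both set-style functions deduplicate the result.
df.select(
    array_except(df.c1, df.c2).alias("only_c1"),
    array_intersect(df.c1, df.c2).alias("both"),
).show()

jdf = spark.createDataFrame([(["a", None, "c"],)], ["data"])
# Without null_replacement the null vanishes; with it, a placeholder appears.
jdf.select(
    array_join(jdf.data, ",").alias("nulls_dropped"),
    array_join(jdf.data, ",", "NULL").alias("nulls_kept"),
).show()

spark.stop()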
+ + Examples + -------- + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_max(df.data).alias('max')).collect() + [Row(max=3), Row(max=10)] + """ + return _invoke_function_over_columns("array_max", col) + + +def array_min(col: "ColumnOrName") -> Column: + """ + Collection function: returns the minimum value of the array. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column or expression + + Returns + ------- + :class:`~pyspark.sql.Column` + minimum value of array. + + Examples + -------- + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_min(df.data).alias('min')).collect() + [Row(min=1), Row(min=-1)] + """ + return _invoke_function_over_columns("array_min", col) + + +def array_position(col: "ColumnOrName", value: Any) -> Column: + """ + Collection function: Locates the position of the first occurrence of the given value + in the given array. Returns null if either of the arguments are null. + + .. versionadded:: 3.4.0 + + Notes + ----- + The position is not zero based, but 1 based index. Returns 0 if the given + value could not be found in the array. + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + value : Any + value to look for. + + Returns + ------- + :class:`~pyspark.sql.Column` + position of the value in the given array if found and 0 otherwise. + + Examples + -------- + >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + >>> df.select(array_position(df.data, "a")).collect() + [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] + """ + return _invoke_function("array_position", _to_col(col), lit(value)) + + +def array_remove(col: "ColumnOrName", element: Any) -> Column: + """ + Collection function: Remove all elements that equal to element from the given array. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing array + element : + element to be removed from the array + + Returns + ------- + :class:`~pyspark.sql.Column` + an array excluding given value. + + Examples + -------- + >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + >>> df.select(array_remove(df.data, 1)).collect() + [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])] + """ + return _invoke_function("array_remove", _to_col(col), lit(element)) + + +def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Column: + """ + Collection function: creates an array containing a column repeated count times. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + column name or column that contains the element to be repeated + count : :class:`~pyspark.sql.Column` or str or int + column name, column, or int containing the number of times to repeat the first argument + + Returns + ------- + :class:`~pyspark.sql.Column` + an array of repeated elements. 
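Two behaviors called out in the docstrings above are easy to trip over: array_position() is 1-based and returns 0 (not null) when the value is absent, and array_remove() strips every matching element, not just the first. A short classic-API sketch:

from pyspark.sql import SparkSession
from pyspark.sql.functions import array_position, array_remove

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ["data"])
# 1-based index; 0 signals "not found".
df.select(array_position(df.data, "a").alias("pos")).show()

rdf = spark.createDataFrame([([1, 2, 3, 1, 1],)], ["data"])
# All occurrences of 1 are removed, not only the first.
rdf.select(array_remove(rdf.data, 1).alias("pruned")).show()

spark.stop()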
+ + Examples + -------- + >>> df = spark.createDataFrame([('ab',)], ['data']) + >>> df.select(array_repeat(df.data, 3).alias('r')).collect() + [Row(r=['ab', 'ab', 'ab'])] + """ + _count = lit(count) if isinstance(count, int) else _to_col(count) + + return _invoke_function("array_repeat", _to_col(col), _count) + + +# TODO(SPARK-41434): need to support LambdaFunction Expression first +# def array_sort( +# col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None +# ) -> Column: +# """ +# Collection function: sorts the input array in ascending order. The elements of the input array +# must be orderable. Null elements will be placed at the end of the returned array. +# +# .. versionadded:: 2.4.0 +# .. versionchanged:: 3.4.0 +# Can take a `comparator` function. +# +# Parameters +# ---------- +# col : :class:`~pyspark.sql.Column` or str +# name of column or expression +# comparator : callable, optional +# A binary ``(Column, Column) -> Column: ...``. +# The comparator will take two +# arguments representing two elements of the array. It returns a negative integer, 0, or a +# positive integer as the first element is less than, equal to, or greater than the second +# element. If the comparator function returns null, the function will fail and raise an +# error. +# +# Returns +# ------- +# :class:`~pyspark.sql.Column` +# sorted array. +# +# Examples +# -------- +# >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) +# >>> df.select(array_sort(df.data).alias('r')).collect() +# [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] +# >>> df = spark.createDataFrame([(["foo", "foobar", None, "bar"],),(["foo"],),([],)], ['data']) +# >>> df.select(array_sort( +# ... "data", +# ... lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise(length(y) - length(x)) +# ... ).alias("r")).collect() +# [Row(r=['foobar', 'foo', None, 'bar']), Row(r=['foo']), Row(r=[])] +# """ +# if comparator is None: +# return _invoke_function_over_columns("array_sort", col) +# else: +# return _invoke_higher_order_function("ArraySort", [col], [comparator]) + + +def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: + """ + Collection function: returns an array of the elements in the union of col1 and col2, + without duplicates. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col1 : :class:`~pyspark.sql.Column` or str + name of column containing array + col2 : :class:`~pyspark.sql.Column` or str + name of column containing array + + Returns + ------- + :class:`~pyspark.sql.Column` + an array of values in union of two arrays. + + Examples + -------- + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_union(df.c1, df.c2)).collect() + [Row(array_union(c1, c2)=['b', 'a', 'c', 'd', 'f'])] + """ + return _invoke_function_over_columns("array_union", col1, col2) + + +def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: + """ + Collection function: returns true if the arrays contain any common non-null element; if not, + returns null if both the arrays are non-empty and any of them contains a null element; returns + false otherwise. + + .. versionadded:: 3.4.0 + + Returns + ------- + :class:`~pyspark.sql.Column` + a column of Boolean type. 
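array_repeat() above shows the client's coercion idiom for mixed arguments: a plain Python int becomes a literal, anything else resolves as a column. The commit message defers the plain-int case over Connect because the client's literals currently come out as BIGINT where the function expects INT (see the commented-out test below). A classic-API sketch; the explicit cast in the column case guards against the same INT-vs-BIGINT mismatch, since Python ints load as BIGINT:

from pyspark.sql import SparkSession
from pyspark.sql.functions import array_repeat

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame([("ab", 2), ("cd", 3)], ["data", "n"])
# A Python int is wrapped as an INT literal on the classic path.
df.select(array_repeat(df.data, 3).alias("fixed")).show()
# A per-row count column also works; cast to INT to satisfy
# array_repeat's second parameter.
df.select(array_repeat(df.data, df.n.cast("int")).alias("per_row")).show()

spark.stop()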
+ + Examples + -------- + >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() + [Row(overlap=True), Row(overlap=False)] + """ + return _invoke_function_over_columns("arrays_overlap", a1, a2) + + +def arrays_zip(*cols: "ColumnOrName") -> Column: + """ + Collection function: Returns a merged array of structs in which the N-th struct contains all + N-th values of input arrays. If one of the arrays is shorter than others then + resulting struct type value will be a `null` for missing elements. + + .. versionadded:: 2.4.0 + + Parameters + ---------- + cols : :class:`~pyspark.sql.Column` or str + columns of arrays to be merged. + + Returns + ------- + :class:`~pyspark.sql.Column` + merged array of entries. + + Examples + -------- + >>> from pyspark.sql.functions import arrays_zip + >>> df = spark.createDataFrame([(([1, 2, 3], [2, 4, 6], [3, 6]))], ['vals1', 'vals2', 'vals3']) + >>> df = df.select(arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')) + >>> df.show(truncate=False) + +------------------------------------+ + |zipped | + +------------------------------------+ + |[{1, 2, 3}, {2, 4, 6}, {3, 6, null}]| + +------------------------------------+ + >>> df.printSchema() + root + |-- zipped: array (nullable = true) + | |-- element: struct (containsNull = false) + | | |-- vals1: long (nullable = true) + | | |-- vals2: long (nullable = true) + | | |-- vals3: long (nullable = true) + """ + return _invoke_function_over_columns("arrays_zip", *cols) + + +def concat(*cols: "ColumnOrName") -> Column: + """ + Concatenates multiple input columns together into a single column. + The function works with strings, numeric, binary and compatible array columns. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + cols : :class:`~pyspark.sql.Column` or str + target column or columns to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + concatenated values. Type of the `Column` depends on input columns' type. + + See Also + -------- + :meth:`pyspark.sql.functions.array_join` : to concatenate string columns with delimiter + + Examples + -------- + >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df = df.select(concat(df.s, df.d).alias('s')) + >>> df.collect() + [Row(s='abcd123')] + >>> df + DataFrame[s: string] + + >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) + >>> df = df.select(concat(df.a, df.b, df.c).alias("arr")) + >>> df.collect() + [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] + >>> df + DataFrame[arr: array<bigint>] + """ + return _invoke_function_over_columns("concat", *cols) + + +def create_map( + *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]] +) -> Column: + """Creates a new map column. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + cols : :class:`~pyspark.sql.Column` or str + column names or :class:`~pyspark.sql.Column`\\s that are + grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...). 
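concat() above is type-polymorphic (strings, binary, arrays), and its null handling contrasts with array_join(): a single null input makes the whole result null rather than being skipped. A classic-API sketch:

from pyspark.sql import SparkSession
from pyspark.sql.functions import concat

spark = SparkSession.builder.master("local[1]").getOrCreate()

sdf = spark.createDataFrame([("abcd", "123")], ["s", "d"])
# String inputs yield a string column.
sdf.select(concat(sdf.s, sdf.d).alias("s")).show()

adf = spark.createDataFrame(
    [([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ["a", "b", "c"]
)
# Array inputs yield an array column; the second row comes back
# null because one of its inputs is null.
adf.select(concat(adf.a, adf.b, adf.c).alias("arr")).show()

spark.stop()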
+ + Examples + -------- + >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age")) + >>> df.select(create_map('name', 'age').alias("map")).collect() + [Row(map={'Alice': 2}), Row(map={'Bob': 5})] + >>> df.select(create_map([df.name, df.age]).alias("map")).collect() + [Row(map={'Alice': 2}), Row(map={'Bob': 5})] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)): + cols = cols[0] # type: ignore[assignment] + return _invoke_function_over_columns("map", *cols) # type: ignore[arg-type] + + # String/Binary functions diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index ee3a927708e..96801c58ca9 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -63,6 +63,24 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase): """These test cases exercise the interface to the proto plan generation but do not call Spark.""" + def compare_by_show(self, df1: Any, df2: Any): + from pyspark.sql.dataframe import DataFrame as SDF + from pyspark.sql.connect.dataframe import DataFrame as CDF + + assert isinstance(df1, (SDF, CDF)) + if isinstance(df1, SDF): + str1 = df1._jdf.showString(20, 20, False) + else: + str1 = df1._show_string(20, 20, False) + + assert isinstance(df2, (SDF, CDF)) + if isinstance(df2, SDF): + str2 = df2._jdf.showString(20, 20, False) + else: + str2 = df2._show_string(20, 20, False) + + self.assertEqual(str1, str2) + def test_normal_functions(self): from pyspark.sql import functions as SF from pyspark.sql.connect import functions as CF @@ -428,6 +446,144 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase): .toPandas(), ) + def test_collection_functions(self): + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + + query = """ + SELECT * FROM VALUES + (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3), 1, 2, 'a'), + (ARRAY('x', NULL), NULL, ARRAY(1, 3), 3, 4, 'x'), + (NULL, ARRAY(-1, -2, -3), Array(), 5, 6, NULL) + AS tab(a, b, c, d, e, f) + """ + # +---------+------------+------------+---+---+----+ + # | a| b| c| d| e| f| + # +---------+------------+------------+---+---+----+ + # | [a, ab]| [1, 2, 3]|[1, null, 3]| 1| 2| a| + # |[x, null]| null| [1, 3]| 3| 4| x| + # | null|[-1, -2, -3]| []| 5| 6|null| + # +---------+------------+------------+---+---+----+ + + cdf = self.connect.sql(query) + sdf = self.spark.sql(query) + + for cfunc, sfunc in [ + (CF.array_distinct, SF.array_distinct), + (CF.array_max, SF.array_max), + (CF.array_min, SF.array_min), + ]: + self.assert_eq( + cdf.select(cfunc("a"), cfunc(cdf.b)).toPandas(), + sdf.select(sfunc("a"), sfunc(sdf.b)).toPandas(), + ) + + for cfunc, sfunc in [ + (CF.array_except, SF.array_except), + (CF.array_intersect, SF.array_intersect), + (CF.array_union, SF.array_union), + (CF.arrays_overlap, SF.arrays_overlap), + ]: + self.assert_eq( + cdf.select(cfunc("b", cdf.c)).toPandas(), + sdf.select(sfunc("b", sdf.c)).toPandas(), + ) + + for cfunc, sfunc in [ + (CF.array_position, SF.array_position), + (CF.array_remove, SF.array_remove), + ]: + self.assert_eq( + cdf.select(cfunc(cdf.a, "ab")).toPandas(), + sdf.select(sfunc(sdf.a, "ab")).toPandas(), + ) + + # test array + self.assert_eq( + cdf.select(CF.array(cdf.d, "e")).toPandas(), + sdf.select(SF.array(sdf.d, "e")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.array(cdf.d, "e", CF.lit(99))).toPandas(), + sdf.select(SF.array(sdf.d, "e", 
SF.lit(99))).toPandas(), + ) + + # test array_contains + self.assert_eq( + cdf.select(CF.array_contains(cdf.a, "ab")).toPandas(), + sdf.select(SF.array_contains(sdf.a, "ab")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.array_contains(cdf.a, cdf.f)).toPandas(), + sdf.select(SF.array_contains(sdf.a, sdf.f)).toPandas(), + ) + + # test array_join + self.assert_eq( + cdf.select( + CF.array_join(cdf.a, ","), CF.array_join("b", ":"), CF.array_join("c", "~") + ).toPandas(), + sdf.select( + SF.array_join(sdf.a, ","), SF.array_join("b", ":"), SF.array_join("c", "~") + ).toPandas(), + ) + self.assert_eq( + cdf.select( + CF.array_join(cdf.a, ",", "_null_"), + CF.array_join("b", ":", ".null."), + CF.array_join("c", "~", "NULL"), + ).toPandas(), + sdf.select( + SF.array_join(sdf.a, ",", "_null_"), + SF.array_join("b", ":", ".null."), + SF.array_join("c", "~", "NULL"), + ).toPandas(), + ) + + # test array_repeat + self.assert_eq( + cdf.select(CF.array_repeat(cdf.f, "d")).toPandas(), + sdf.select(SF.array_repeat(sdf.f, "d")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.array_repeat("f", cdf.d)).toPandas(), + sdf.select(SF.array_repeat("f", sdf.d)).toPandas(), + ) + # TODO: Make Literal contains DataType + # Cannot resolve "array_repeat(f, 3)" due to data type mismatch: + # Parameter 2 requires the "INT" type, however "3" has the type "BIGINT". + # self.assert_eq( + # cdf.select(CF.array_repeat("f", 3)).toPandas(), + # sdf.select(SF.array_repeat("f", 3)).toPandas(), + # ) + + # test arrays_zip + # TODO: Make toPandas support complex nested types like Array<Struct> + # DataFrame.iloc[:, 0] (column name="arrays_zip(b, c)") values are different (66.66667 %) + # [index]: [0, 1, 2] + # [left]: [[{'b': 1, 'c': 1.0}, {'b': 2, 'c': None}, {'b': 3, 'c': 3.0}], None, + # [{'b': -1, 'c': None}, {'b': -2, 'c': None}, {'b': -3, 'c': None}]] + # [right]: [[(1, 1), (2, None), (3, 3)], None, [(-1, None), (-2, None), (-3, None)]] + self.compare_by_show( + cdf.select(CF.arrays_zip(cdf.b, "c")), + sdf.select(SF.arrays_zip(sdf.b, "c")), + ) + + # test concat + self.assert_eq( + cdf.select(CF.concat("d", cdf.e, CF.lit(-1))).toPandas(), + sdf.select(SF.concat("d", sdf.e, SF.lit(-1))).toPandas(), + ) + + # test create_map + self.compare_by_show( + cdf.select(CF.create_map(cdf.d, cdf.e)), sdf.select(SF.create_map(sdf.d, sdf.e)) + ) + self.compare_by_show( + cdf.select(CF.create_map(cdf.d, "e", "e", CF.lit(1))), + sdf.select(SF.create_map(sdf.d, "e", "e", SF.lit(1))), + ) + def test_string_functions(self): from pyspark.sql import functions as SF from pyspark.sql.connect import functions as CF --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
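As a footnote on the test helper: compare_by_show exists because toPandas() cannot yet round-trip nested results such as array<struct> for equality checks, so the tests render both DataFrames with identical show parameters and compare the strings. The same idea in classic-API-only form (showString is the private JVM call that DataFrame.show() itself uses, as seen in the helper above):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

df1 = spark.sql("SELECT arrays_zip(array(1, 2), array(3, 4)) AS z")
df2 = spark.sql("SELECT arrays_zip(array(1, 2), array(3, 4)) AS z")

# showString(numRows, truncate, vertical), as in compare_by_show.
s1 = df1._jdf.showString(20, 20, False)
s2 = df2._jdf.showString(20, 20, False)
assert s1 == s2

spark.stop()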