This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d53ba6276f4 [SPARK-41436][CONNECT][PYTHON] Implement `collection` functions: A~C
d53ba6276f4 is described below

commit d53ba6276f407b6e090c9610b85a56d047e56f73
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Dec 8 16:59:13 2022 +0800

    [SPARK-41436][CONNECT][PYTHON] Implement `collection` functions: A~C
    
    ### What changes were proposed in this pull request?
     Implement [`collection` functions](https://github.com/apache/spark/blob/master/python/docs/source/reference/pyspark.sql/functions.rst#collection-functions) alphabetically, this PR contains `A` ~ `C` except:
    
    - `aggregate`, `array_sort` - need support for LambdaFunction Expression
    - the `int count` in `array_repeat` - needs datatype support in LiteralExpression in the Python client
    
    ### Why are the changes needed?
    
    For API coverage
    
    ### Does this PR introduce _any_ user-facing change?
    new APIs
    
    ### How was this patch tested?
    added UT
    
    Closes #38961 from zhengruifeng/connect_function_collect_1.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/functions.py            | 608 ++++++++++++++++++++-
 .../sql/tests/connect/test_connect_function.py     | 156 ++++++
 2 files changed, 763 insertions(+), 1 deletion(-)
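
For illustration only (not part of the patch): a minimal sketch of how the newly
added Connect collection functions are expected to be used from the Python client.
The DataFrame `cdf` and its array<string> column `data` are assumptions here,
mirroring the Connect DataFrames built via `connect.sql(...)` in the tests below.

    from pyspark.sql.connect import functions as CF

    # Assumed: `cdf` is a Spark Connect DataFrame with an array<string> column
    # named "data", e.g. rows like (["a", "b", "c"],) and (["a", None],).
    cdf.select(
        CF.array_contains("data", "a"),      # null / true / false per row
        CF.array_distinct(cdf.data),         # duplicates removed
        CF.array_join("data", ",", "NULL"),  # nulls rendered as "NULL"
    ).toPandas()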

diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 8b36647ae5b..7eb17bd89ac 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -90,7 +90,10 @@ column = col
 
 
 def lit(col: Any) -> Column:
-    return Column(LiteralExpression(col))
+    if isinstance(col, Column):
+        return col
+    else:
+        return Column(LiteralExpression(col))
 
 
 # def bitwiseNOT(col: "ColumnOrName") -> Column:
@@ -3208,6 +3211,609 @@ def variance(col: "ColumnOrName") -> Column:
     return var_samp(col)
 
 
+# Collection Functions
+
+
+# TODO(SPARK-41434): need to support LambdaFunction Expression first
+# def aggregate(
+#         col: "ColumnOrName",
+#         initialValue: "ColumnOrName",
+#         merge: Callable[[Column, Column], Column],
+#         finish: Optional[Callable[[Column], Column]] = None,
+# ) -> Column:
+#     """
+#     Applies a binary operator to an initial state and all elements in the array,
+#     and reduces this to a single state. The final state is converted into the final result
+#     by applying a finish function.
+#
+#     Both functions can use methods of :class:`~pyspark.sql.Column`, functions defined in
+#     :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
+#     Python ``UserDefinedFunctions`` are not supported
+#     (`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
+#
+#     .. versionadded:: 3.1.0
+#
+#     Parameters
+#     ----------
+#     col : :class:`~pyspark.sql.Column` or str
+#         name of column or expression
+#     initialValue : :class:`~pyspark.sql.Column` or str
+#         initial value. Name of column or expression
+#     merge : function
+#         a binary function ``(acc: Column, x: Column) -> Column...`` returning expression
+#         of the same type as ``zero``
+#     finish : function
+#         an optional unary function ``(x: Column) -> Column: ...``
+#         used to convert accumulated value.
+#
+#     Returns
+#     -------
+#     :class:`~pyspark.sql.Column`
+#         final value after aggregate function is applied.
+#
+#     Examples
+#     --------
+#     >>> df = spark.createDataFrame([(1, [20.0, 4.0, 2.0, 6.0, 10.0])], ("id", "values"))
+#     >>> df.select(aggregate("values", lit(0.0), lambda acc, x: acc + x).alias("sum")).show()
+#     +----+
+#     | sum|
+#     +----+
+#     |42.0|
+#     +----+
+#
+#     >>> def merge(acc, x):
+#     ...     count = acc.count + 1
+#     ...     sum = acc.sum + x
+#     ...     return struct(count.alias("count"), sum.alias("sum"))
+#     >>> df.select(
+#     ...     aggregate(
+#     ...         "values",
+#     ...         struct(lit(0).alias("count"), lit(0.0).alias("sum")),
+#     ...         merge,
+#     ...         lambda acc: acc.sum / acc.count,
+#     ...     ).alias("mean")
+#     ... ).show()
+#     +----+
+#     |mean|
+#     +----+
+#     | 8.4|
+#     +----+
+#     """
+#     if finish is not None:
+#         return _invoke_higher_order_function("ArrayAggregate", [col, 
initialValue],
+#           [merge, finish])
+#
+#     else:
+#         return _invoke_higher_order_function("ArrayAggregate", [col, 
initialValue],
+#           [merge])
+
+
+def array(*cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column:
+    """Creates a new array column.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    cols : :class:`~pyspark.sql.Column` or str
+        column names or :class:`~pyspark.sql.Column`\\s that have
+        the same data type.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        a column of array type.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age"))
+    >>> df.select(array('age', 'age').alias("arr")).collect()
+    [Row(arr=[2, 2]), Row(arr=[5, 5])]
+    >>> df.select(array([df.age, df.age]).alias("arr")).collect()
+    [Row(arr=[2, 2]), Row(arr=[5, 5])]
+    >>> df.select(array('age', 'age').alias("col")).printSchema()
+    root
+     |-- col: array (nullable = false)
+     |    |-- element: long (containsNull = true)
+    """
+    if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
+        cols = cols[0]  # type: ignore[assignment]
+    return _invoke_function_over_columns("array", *cols)  # type: 
ignore[arg-type]
+
+
+def array_contains(col: "ColumnOrName", value: Any) -> Column:
+    """
+    Collection function: returns null if the array is null, true if the array contains the
+    given value, and false otherwise.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+    value :
+        value or column to check for in array
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        a column of Boolean type.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
+    >>> df.select(array_contains(df.data, "a")).collect()
+    [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
+    >>> df.select(array_contains(df.data, lit("a"))).collect()
+    [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
+    """
+    return _invoke_function("array_contains", _to_col(col), lit(value))
+
+
+def array_distinct(col: "ColumnOrName") -> Column:
+    """
+    Collection function: removes duplicate values from the array.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column or expression
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array of unique values.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data'])
+    >>> df.select(array_distinct(df.data)).collect()
+    [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])]
+    """
+    return _invoke_function_over_columns("array_distinct", col)
+
+
+def array_except(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
+    """
+    Collection function: returns an array of the elements in col1 but not in col2,
+    without duplicates.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col1 : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+    col2 : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array of values from first array that are not in the second.
+
+    Examples
+    --------
+    >>> from pyspark.sql import Row
+    >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
+    >>> df.select(array_except(df.c1, df.c2)).collect()
+    [Row(array_except(c1, c2)=['b'])]
+    """
+    return _invoke_function_over_columns("array_except", col1, col2)
+
+
+def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
+    """
+    Collection function: returns an array of the elements in the intersection of col1 and col2,
+    without duplicates.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col1 : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+    col2 : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array of values in the intersection of two arrays.
+
+    Examples
+    --------
+    >>> from pyspark.sql import Row
+    >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
+    >>> df.select(array_intersect(df.c1, df.c2)).collect()
+    [Row(array_intersect(c1, c2)=['a', 'c'])]
+    """
+    return _invoke_function_over_columns("array_intersect", col1, col2)
+
+
+def array_join(
+    col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None
+) -> Column:
+    """
+    Concatenates the elements of `column` using the `delimiter`. Null values are replaced with
+    `null_replacement` if set, otherwise they are ignored.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        target column to work on.
+    delimiter : str
+        delimiter used to concatenate elements
+    null_replacement : str, optional
+        if set then null values will be replaced by this value
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        a column of string type. Concatenated values.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], 
['data'])
+    >>> df.select(array_join(df.data, ",").alias("joined")).collect()
+    [Row(joined='a,b,c'), Row(joined='a')]
+    >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect()
+    [Row(joined='a,b,c'), Row(joined='a,NULL')]
+    """
+    if null_replacement is None:
+        return _invoke_function("array_join", _to_col(col), lit(delimiter))
+    else:
+        return _invoke_function("array_join", _to_col(col), lit(delimiter), 
lit(null_replacement))
+
+
+def array_max(col: "ColumnOrName") -> Column:
+    """
+    Collection function: returns the maximum value of the array.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column or expression
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        maximum value of an array.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
+    >>> df.select(array_max(df.data).alias('max')).collect()
+    [Row(max=3), Row(max=10)]
+    """
+    return _invoke_function_over_columns("array_max", col)
+
+
+def array_min(col: "ColumnOrName") -> Column:
+    """
+    Collection function: returns the minimum value of the array.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column or expression
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        minimum value of array.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
+    >>> df.select(array_min(df.data).alias('min')).collect()
+    [Row(min=1), Row(min=-1)]
+    """
+    return _invoke_function_over_columns("array_min", col)
+
+
+def array_position(col: "ColumnOrName", value: Any) -> Column:
+    """
+    Collection function: Locates the position of the first occurrence of the given value
+    in the given array. Returns null if either of the arguments are null.
+
+    .. versionadded:: 3.4.0
+
+    Notes
+    -----
+    The position is not zero based, but 1 based index. Returns 0 if the given
+    value could not be found in the array.
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        target column to work on.
+    value : Any
+        value to look for.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        position of the value in the given array if found and 0 otherwise.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])
+    >>> df.select(array_position(df.data, "a")).collect()
+    [Row(array_position(data, a)=3), Row(array_position(data, a)=0)]
+    """
+    return _invoke_function("array_position", _to_col(col), lit(value))
+
+
+def array_remove(col: "ColumnOrName", element: Any) -> Column:
+    """
+    Collection function: Remove all elements that equal to element from the given array.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+    element :
+        element to be removed from the array
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array excluding given value.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data'])
+    >>> df.select(array_remove(df.data, 1)).collect()
+    [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])]
+    """
+    return _invoke_function("array_remove", _to_col(col), lit(element))
+
+
+def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Column:
+    """
+    Collection function: creates an array containing a column repeated count times.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        column name or column that contains the element to be repeated
+    count : :class:`~pyspark.sql.Column` or str or int
+        column name, column, or int containing the number of times to repeat the first argument
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array of repeated elements.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([('ab',)], ['data'])
+    >>> df.select(array_repeat(df.data, 3).alias('r')).collect()
+    [Row(r=['ab', 'ab', 'ab'])]
+    """
+    _count = lit(count) if isinstance(count, int) else _to_col(count)
+
+    return _invoke_function("array_repeat", _to_col(col), _count)
+
+
+# TODO(SPARK-41434): need to support LambdaFunction Expression first
+# def array_sort(
+#         col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], 
Column]] = None
+# ) -> Column:
+#     """
+#     Collection function: sorts the input array in ascending order. The elements of the input array
+#     must be orderable. Null elements will be placed at the end of the returned array.
+#
+#     .. versionadded:: 2.4.0
+#     .. versionchanged:: 3.4.0
+#         Can take a `comparator` function.
+#
+#     Parameters
+#     ----------
+#     col : :class:`~pyspark.sql.Column` or str
+#         name of column or expression
+#     comparator : callable, optional
+#         A binary ``(Column, Column) -> Column: ...``.
+#         The comparator will take two
+#         arguments representing two elements of the array. It returns a negative integer, 0, or a
+#         positive integer as the first element is less than, equal to, or greater than the second
+#         element. If the comparator function returns null, the function will fail and raise an
+#         error.
+#
+#     Returns
+#     -------
+#     :class:`~pyspark.sql.Column`
+#         sorted array.
+#
+#     Examples
+#     --------
+#     >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
+#     >>> df.select(array_sort(df.data).alias('r')).collect()
+#     [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])]
+#     >>> df = spark.createDataFrame([(["foo", "foobar", None, 
"bar"],),(["foo"],),([],)], ['data'])
+#     >>> df.select(array_sort(
+#     ...     "data",
+#     ...     lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise(length(y) - length(x))
+#     ... ).alias("r")).collect()
+#     [Row(r=['foobar', 'foo', None, 'bar']), Row(r=['foo']), Row(r=[])]
+#     """
+#     if comparator is None:
+#         return _invoke_function_over_columns("array_sort", col)
+#     else:
+#         return _invoke_higher_order_function("ArraySort", [col], 
[comparator])
+
+
+def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
+    """
+    Collection function: returns an array of the elements in the union of col1 and col2,
+    without duplicates.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col1 : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+    col2 : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array of values in union of two arrays.
+
+    Examples
+    --------
+    >>> from pyspark.sql import Row
+    >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
+    >>> df.select(array_union(df.c1, df.c2)).collect()
+    [Row(array_union(c1, c2)=['b', 'a', 'c', 'd', 'f'])]
+    """
+    return _invoke_function_over_columns("array_union", col1, col2)
+
+
+def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column:
+    """
+    Collection function: returns true if the arrays contain any common non-null element; if not,
+    returns null if both the arrays are non-empty and any of them contains a null element; returns
+    false otherwise.
+
+    .. versionadded:: 3.4.0
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        a column of Boolean type.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", 
"c"])], ['x', 'y'])
+    >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect()
+    [Row(overlap=True), Row(overlap=False)]
+    """
+    return _invoke_function_over_columns("arrays_overlap", a1, a2)
+
+
+def arrays_zip(*cols: "ColumnOrName") -> Column:
+    """
+    Collection function: Returns a merged array of structs in which the N-th struct contains all
+    N-th values of input arrays. If one of the arrays is shorter than others then
+    resulting struct type value will be a `null` for missing elements.
+
+    .. versionadded:: 2.4.0
+
+    Parameters
+    ----------
+    cols : :class:`~pyspark.sql.Column` or str
+        columns of arrays to be merged.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        merged array of entries.
+
+    Examples
+    --------
+    >>> from pyspark.sql.functions import arrays_zip
+    >>> df = spark.createDataFrame([(([1, 2, 3], [2, 4, 6], [3, 6]))], ['vals1', 'vals2', 'vals3'])
+    >>> df = df.select(arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped'))
+    >>> df.show(truncate=False)
+    +------------------------------------+
+    |zipped                              |
+    +------------------------------------+
+    |[{1, 2, 3}, {2, 4, 6}, {3, 6, null}]|
+    +------------------------------------+
+    >>> df.printSchema()
+    root
+     |-- zipped: array (nullable = true)
+     |    |-- element: struct (containsNull = false)
+     |    |    |-- vals1: long (nullable = true)
+     |    |    |-- vals2: long (nullable = true)
+     |    |    |-- vals3: long (nullable = true)
+    """
+    return _invoke_function_over_columns("arrays_zip", *cols)
+
+
+def concat(*cols: "ColumnOrName") -> Column:
+    """
+    Concatenates multiple input columns together into a single column.
+    The function works with strings, numeric, binary and compatible array columns.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    cols : :class:`~pyspark.sql.Column` or str
+        target column or columns to work on.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        concatenated values. Type of the `Column` depends on input columns' type.
+
+    See Also
+    --------
+    :meth:`pyspark.sql.functions.array_join` : to concatenate string columns with delimiter
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
+    >>> df = df.select(concat(df.s, df.d).alias('s'))
+    >>> df.collect()
+    [Row(s='abcd123')]
+    >>> df
+    DataFrame[s: string]
+
+    >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
+    >>> df = df.select(concat(df.a, df.b, df.c).alias("arr"))
+    >>> df.collect()
+    [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
+    >>> df
+    DataFrame[arr: array<bigint>]
+    """
+    return _invoke_function_over_columns("concat", *cols)
+
+
+def create_map(
+    *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
+) -> Column:
+    """Creates a new map column.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    cols : :class:`~pyspark.sql.Column` or str
+        column names or :class:`~pyspark.sql.Column`\\s that are
+        grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...).
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age"))
+    >>> df.select(create_map('name', 'age').alias("map")).collect()
+    [Row(map={'Alice': 2}), Row(map={'Bob': 5})]
+    >>> df.select(create_map([df.name, df.age]).alias("map")).collect()
+    [Row(map={'Alice': 2}), Row(map={'Bob': 5})]
+    """
+    if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
+        cols = cols[0]  # type: ignore[assignment]
+    return _invoke_function_over_columns("map", *cols)  # type: 
ignore[arg-type]
+
+
 # String/Binary functions
 
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index ee3a927708e..96801c58ca9 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -63,6 +63,24 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase):
     """These test cases exercise the interface to the proto plan
     generation but do not call Spark."""
 
+    def compare_by_show(self, df1: Any, df2: Any):
+        from pyspark.sql.dataframe import DataFrame as SDF
+        from pyspark.sql.connect.dataframe import DataFrame as CDF
+
+        assert isinstance(df1, (SDF, CDF))
+        if isinstance(df1, SDF):
+            str1 = df1._jdf.showString(20, 20, False)
+        else:
+            str1 = df1._show_string(20, 20, False)
+
+        assert isinstance(df2, (SDF, CDF))
+        if isinstance(df2, SDF):
+            str2 = df2._jdf.showString(20, 20, False)
+        else:
+            str2 = df2._show_string(20, 20, False)
+
+        self.assertEqual(str1, str2)
+
     def test_normal_functions(self):
         from pyspark.sql import functions as SF
         from pyspark.sql.connect import functions as CF
@@ -428,6 +446,144 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase):
             .toPandas(),
         )
 
+    def test_collection_functions(self):
+        from pyspark.sql import functions as SF
+        from pyspark.sql.connect import functions as CF
+
+        query = """
+            SELECT * FROM VALUES
+            (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3), 1, 2, 'a'),
+            (ARRAY('x', NULL), NULL, ARRAY(1, 3), 3, 4, 'x'),
+            (NULL, ARRAY(-1, -2, -3), Array(), 5, 6, NULL)
+            AS tab(a, b, c, d, e, f)
+            """
+        # +---------+------------+------------+---+---+----+
+        # |        a|           b|           c|  d|  e|   f|
+        # +---------+------------+------------+---+---+----+
+        # |  [a, ab]|   [1, 2, 3]|[1, null, 3]|  1|  2|   a|
+        # |[x, null]|        null|      [1, 3]|  3|  4|   x|
+        # |     null|[-1, -2, -3]|          []|  5|  6|null|
+        # +---------+------------+------------+---+---+----+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        for cfunc, sfunc in [
+            (CF.array_distinct, SF.array_distinct),
+            (CF.array_max, SF.array_max),
+            (CF.array_min, SF.array_min),
+        ]:
+            self.assert_eq(
+                cdf.select(cfunc("a"), cfunc(cdf.b)).toPandas(),
+                sdf.select(sfunc("a"), sfunc(sdf.b)).toPandas(),
+            )
+
+        for cfunc, sfunc in [
+            (CF.array_except, SF.array_except),
+            (CF.array_intersect, SF.array_intersect),
+            (CF.array_union, SF.array_union),
+            (CF.arrays_overlap, SF.arrays_overlap),
+        ]:
+            self.assert_eq(
+                cdf.select(cfunc("b", cdf.c)).toPandas(),
+                sdf.select(sfunc("b", sdf.c)).toPandas(),
+            )
+
+        for cfunc, sfunc in [
+            (CF.array_position, SF.array_position),
+            (CF.array_remove, SF.array_remove),
+        ]:
+            self.assert_eq(
+                cdf.select(cfunc(cdf.a, "ab")).toPandas(),
+                sdf.select(sfunc(sdf.a, "ab")).toPandas(),
+            )
+
+        # test array
+        self.assert_eq(
+            cdf.select(CF.array(cdf.d, "e")).toPandas(),
+            sdf.select(SF.array(sdf.d, "e")).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.array(cdf.d, "e", CF.lit(99))).toPandas(),
+            sdf.select(SF.array(sdf.d, "e", SF.lit(99))).toPandas(),
+        )
+
+        # test array_contains
+        self.assert_eq(
+            cdf.select(CF.array_contains(cdf.a, "ab")).toPandas(),
+            sdf.select(SF.array_contains(sdf.a, "ab")).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.array_contains(cdf.a, cdf.f)).toPandas(),
+            sdf.select(SF.array_contains(sdf.a, sdf.f)).toPandas(),
+        )
+
+        # test array_join
+        self.assert_eq(
+            cdf.select(
+                CF.array_join(cdf.a, ","), CF.array_join("b", ":"), CF.array_join("c", "~")
+            ).toPandas(),
+            sdf.select(
+                SF.array_join(sdf.a, ","), SF.array_join("b", ":"), SF.array_join("c", "~")
+            ).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(
+                CF.array_join(cdf.a, ",", "_null_"),
+                CF.array_join("b", ":", ".null."),
+                CF.array_join("c", "~", "NULL"),
+            ).toPandas(),
+            sdf.select(
+                SF.array_join(sdf.a, ",", "_null_"),
+                SF.array_join("b", ":", ".null."),
+                SF.array_join("c", "~", "NULL"),
+            ).toPandas(),
+        )
+
+        # test array_repeat
+        self.assert_eq(
+            cdf.select(CF.array_repeat(cdf.f, "d")).toPandas(),
+            sdf.select(SF.array_repeat(sdf.f, "d")).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.array_repeat("f", cdf.d)).toPandas(),
+            sdf.select(SF.array_repeat("f", sdf.d)).toPandas(),
+        )
+        # TODO: Make Literal contains DataType
+        #   Cannot resolve "array_repeat(f, 3)" due to data type mismatch:
+        #   Parameter 2 requires the "INT" type, however "3" has the type "BIGINT".
+        # self.assert_eq(
+        #     cdf.select(CF.array_repeat("f", 3)).toPandas(),
+        #     sdf.select(SF.array_repeat("f", 3)).toPandas(),
+        # )
+
+        # test arrays_zip
+        # TODO: Make toPandas support complex nested types like Array<Struct>
+        # DataFrame.iloc[:, 0] (column name="arrays_zip(b, c)") values are different (66.66667 %)
+        # [index]: [0, 1, 2]
+        # [left]:  [[{'b': 1, 'c': 1.0}, {'b': 2, 'c': None}, {'b': 3, 'c': 3.0}], None,
+        #           [{'b': -1, 'c': None}, {'b': -2, 'c': None}, {'b': -3, 'c': None}]]
+        # [right]: [[(1, 1), (2, None), (3, 3)], None, [(-1, None), (-2, None), (-3, None)]]
+        self.compare_by_show(
+            cdf.select(CF.arrays_zip(cdf.b, "c")),
+            sdf.select(SF.arrays_zip(sdf.b, "c")),
+        )
+
+        # test concat
+        self.assert_eq(
+            cdf.select(CF.concat("d", cdf.e, CF.lit(-1))).toPandas(),
+            sdf.select(SF.concat("d", sdf.e, SF.lit(-1))).toPandas(),
+        )
+
+        # test create_map
+        self.compare_by_show(
+            cdf.select(CF.create_map(cdf.d, cdf.e)), sdf.select(SF.create_map(sdf.d, sdf.e))
+        )
+        self.compare_by_show(
+            cdf.select(CF.create_map(cdf.d, "e", "e", CF.lit(1))),
+            sdf.select(SF.create_map(sdf.d, "e", "e", SF.lit(1))),
+        )
+
     def test_string_functions(self):
         from pyspark.sql import functions as SF
         from pyspark.sql.connect import functions as CF

