This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new e029559  test: parametrize test_array_functions (#678)
e029559 is described below

commit e029559f314c71d1050df985e0418307b759af4a
Author: Michael J Ward <[email protected]>
AuthorDate: Thu May 23 17:08:53 2024 -0500

    test: parametrize test_array_functions (#678)
    
    test_array_functions now has 56 passing test cases and 1 expected failure
    (the `list_slice`/`array_slice` case with a negative start index is the
    expected failure; ref #670).
    
    test_array_function_flatten was broken out as a single test because its
    input is an outlier: it flattens a literal of the whole data set rather
    than operating on the `arr` column.
    
    test_array_function_obj_tests uses a different set of asserts (exact
    equality rather than approximate numeric comparison), so it was broken out
    for its 4 test cases.
    
    Ref #671
---
 python/datafusion/tests/test_functions.py | 414 ++++++++++++++++--------------
 1 file changed, 223 insertions(+), 191 deletions(-)

diff --git a/python/datafusion/tests/test_functions.py 
b/python/datafusion/tests/test_functions.py
index d34e46b..493b6b6 100644
--- a/python/datafusion/tests/test_functions.py
+++ b/python/datafusion/tests/test_functions.py
@@ -209,324 +209,356 @@ def test_math_functions():
     )
 
 
-def test_array_functions():
-    data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
-    ctx = SessionContext()
-    batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], 
names=["arr"])
-    df = ctx.create_dataframe([[batch]])
+def py_indexof(arr, v):
+    try:
+        return arr.index(v) + 1
+    except ValueError:
+        return np.nan
+
 
-    def py_indexof(arr, v):
+def py_arr_remove(arr, v, n=None):
+    new_arr = arr[:]
+    found = 0
+    while found != n:
         try:
-            return arr.index(v) + 1
+            new_arr.remove(v)
+            found += 1
         except ValueError:
-            return np.nan
-
-    def py_arr_remove(arr, v, n=None):
-        new_arr = arr[:]
-        found = 0
-        while found != n:
-            try:
-                new_arr.remove(v)
-                found += 1
-            except ValueError:
-                break
-
-        return new_arr
-
-    def py_arr_replace(arr, from_, to, n=None):
-        new_arr = arr[:]
-        found = 0
-        while found != n:
-            try:
-                idx = new_arr.index(from_)
-                new_arr[idx] = to
-                found += 1
-            except ValueError:
-                break
-
-        return new_arr
-
-    def py_arr_resize(arr, size, value):
-        arr = np.asarray(arr)
-        return np.pad(
-            arr,
-            [(0, size - arr.shape[0])],
-            "constant",
-            constant_values=value,
-        )
+            break
 
-    def py_flatten(arr):
-        result = []
-        for elem in arr:
-            if isinstance(elem, list):
-                result.extend(py_flatten(elem))
-            else:
-                result.append(elem)
-        return result
+    return new_arr
 
-    col = column("arr")
-    test_items = [
+
+def py_arr_replace(arr, from_, to, n=None):
+    new_arr = arr[:]
+    found = 0
+    while found != n:
+        try:
+            idx = new_arr.index(from_)
+            new_arr[idx] = to
+            found += 1
+        except ValueError:
+            break
+
+    return new_arr
+
+
+def py_arr_resize(arr, size, value):
+    arr = np.asarray(arr)
+    return np.pad(
+        arr,
+        [(0, size - arr.shape[0])],
+        "constant",
+        constant_values=value,
+    )
+
+
+def py_flatten(arr):
+    result = []
+    for elem in arr:
+        if isinstance(elem, list):
+            result.extend(py_flatten(elem))
+        else:
+            result.append(elem)
+    return result
+
+
+@pytest.mark.parametrize(
+    ("stmt", "py_expr"),
+    [
         [
-            f.array_append(col, literal(99.0)),
-            lambda: [np.append(arr, 99.0) for arr in data],
+            lambda col: f.array_append(col, literal(99.0)),
+            lambda data: [np.append(arr, 99.0) for arr in data],
         ],
         [
-            f.array_push_back(col, literal(99.0)),
-            lambda: [np.append(arr, 99.0) for arr in data],
+            lambda col: f.array_push_back(col, literal(99.0)),
+            lambda data: [np.append(arr, 99.0) for arr in data],
         ],
         [
-            f.list_append(col, literal(99.0)),
-            lambda: [np.append(arr, 99.0) for arr in data],
+            lambda col: f.list_append(col, literal(99.0)),
+            lambda data: [np.append(arr, 99.0) for arr in data],
         ],
         [
-            f.list_push_back(col, literal(99.0)),
-            lambda: [np.append(arr, 99.0) for arr in data],
+            lambda col: f.list_push_back(col, literal(99.0)),
+            lambda data: [np.append(arr, 99.0) for arr in data],
         ],
         [
-            f.array_concat(col, col),
-            lambda: [np.concatenate([arr, arr]) for arr in data],
+            lambda col: f.array_concat(col, col),
+            lambda data: [np.concatenate([arr, arr]) for arr in data],
         ],
         [
-            f.array_cat(col, col),
-            lambda: [np.concatenate([arr, arr]) for arr in data],
+            lambda col: f.array_cat(col, col),
+            lambda data: [np.concatenate([arr, arr]) for arr in data],
         ],
         [
-            f.array_dims(col),
-            lambda: [[len(r)] for r in data],
+            lambda col: f.array_dims(col),
+            lambda data: [[len(r)] for r in data],
         ],
         [
-            f.array_distinct(col),
-            lambda: [list(set(r)) for r in data],
+            lambda col: f.array_distinct(col),
+            lambda data: [list(set(r)) for r in data],
         ],
         [
-            f.list_distinct(col),
-            lambda: [list(set(r)) for r in data],
+            lambda col: f.list_distinct(col),
+            lambda data: [list(set(r)) for r in data],
         ],
         [
-            f.list_dims(col),
-            lambda: [[len(r)] for r in data],
+            lambda col: f.list_dims(col),
+            lambda data: [[len(r)] for r in data],
         ],
         [
-            f.array_element(col, literal(1)),
-            lambda: [r[0] for r in data],
+            lambda col: f.array_element(col, literal(1)),
+            lambda data: [r[0] for r in data],
         ],
         [
-            f.array_extract(col, literal(1)),
-            lambda: [r[0] for r in data],
+            lambda col: f.array_extract(col, literal(1)),
+            lambda data: [r[0] for r in data],
         ],
         [
-            f.list_element(col, literal(1)),
-            lambda: [r[0] for r in data],
+            lambda col: f.list_element(col, literal(1)),
+            lambda data: [r[0] for r in data],
         ],
         [
-            f.list_extract(col, literal(1)),
-            lambda: [r[0] for r in data],
+            lambda col: f.list_extract(col, literal(1)),
+            lambda data: [r[0] for r in data],
         ],
         [
-            f.array_length(col),
-            lambda: [len(r) for r in data],
+            lambda col: f.array_length(col),
+            lambda data: [len(r) for r in data],
         ],
         [
-            f.list_length(col),
-            lambda: [len(r) for r in data],
+            lambda col: f.list_length(col),
+            lambda data: [len(r) for r in data],
         ],
         [
-            f.array_has(col, literal(1.0)),
-            lambda: [1.0 in r for r in data],
+            lambda col: f.array_has(col, literal(1.0)),
+            lambda data: [1.0 in r for r in data],
         ],
         [
-            f.array_has_all(col, f.make_array(*[literal(v) for v in [1.0, 3.0, 
5.0]])),
-            lambda: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
+            lambda col: f.array_has_all(
+                col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
+            ),
+            lambda data: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in 
data],
         ],
         [
-            f.array_has_any(col, f.make_array(*[literal(v) for v in [1.0, 3.0, 
5.0]])),
-            lambda: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
+            lambda col: f.array_has_any(
+                col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
+            ),
+            lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in 
data],
         ],
         [
-            f.array_position(col, literal(1.0)),
-            lambda: [py_indexof(r, 1.0) for r in data],
+            lambda col: f.array_position(col, literal(1.0)),
+            lambda data: [py_indexof(r, 1.0) for r in data],
         ],
         [
-            f.array_indexof(col, literal(1.0)),
-            lambda: [py_indexof(r, 1.0) for r in data],
+            lambda col: f.array_indexof(col, literal(1.0)),
+            lambda data: [py_indexof(r, 1.0) for r in data],
         ],
         [
-            f.list_position(col, literal(1.0)),
-            lambda: [py_indexof(r, 1.0) for r in data],
+            lambda col: f.list_position(col, literal(1.0)),
+            lambda data: [py_indexof(r, 1.0) for r in data],
         ],
         [
-            f.list_indexof(col, literal(1.0)),
-            lambda: [py_indexof(r, 1.0) for r in data],
+            lambda col: f.list_indexof(col, literal(1.0)),
+            lambda data: [py_indexof(r, 1.0) for r in data],
         ],
         [
-            f.array_positions(col, literal(1.0)),
-            lambda: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in 
data],
+            lambda col: f.array_positions(col, literal(1.0)),
+            lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r 
in data],
         ],
         [
-            f.list_positions(col, literal(1.0)),
-            lambda: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in 
data],
+            lambda col: f.list_positions(col, literal(1.0)),
+            lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r 
in data],
         ],
         [
-            f.array_ndims(col),
-            lambda: [np.array(r).ndim for r in data],
+            lambda col: f.array_ndims(col),
+            lambda data: [np.array(r).ndim for r in data],
         ],
         [
-            f.list_ndims(col),
-            lambda: [np.array(r).ndim for r in data],
+            lambda col: f.list_ndims(col),
+            lambda data: [np.array(r).ndim for r in data],
         ],
         [
-            f.array_prepend(literal(99.0), col),
-            lambda: [np.insert(arr, 0, 99.0) for arr in data],
+            lambda col: f.array_prepend(literal(99.0), col),
+            lambda data: [np.insert(arr, 0, 99.0) for arr in data],
         ],
         [
-            f.array_push_front(literal(99.0), col),
-            lambda: [np.insert(arr, 0, 99.0) for arr in data],
+            lambda col: f.array_push_front(literal(99.0), col),
+            lambda data: [np.insert(arr, 0, 99.0) for arr in data],
         ],
         [
-            f.list_prepend(literal(99.0), col),
-            lambda: [np.insert(arr, 0, 99.0) for arr in data],
+            lambda col: f.list_prepend(literal(99.0), col),
+            lambda data: [np.insert(arr, 0, 99.0) for arr in data],
         ],
         [
-            f.list_push_front(literal(99.0), col),
-            lambda: [np.insert(arr, 0, 99.0) for arr in data],
+            lambda col: f.list_push_front(literal(99.0), col),
+            lambda data: [np.insert(arr, 0, 99.0) for arr in data],
         ],
         [
-            f.array_pop_back(col),
-            lambda: [arr[:-1] for arr in data],
+            lambda col: f.array_pop_back(col),
+            lambda data: [arr[:-1] for arr in data],
         ],
         [
-            f.array_pop_front(col),
-            lambda: [arr[1:] for arr in data],
+            lambda col: f.array_pop_front(col),
+            lambda data: [arr[1:] for arr in data],
         ],
         [
-            f.array_remove(col, literal(3.0)),
-            lambda: [py_arr_remove(arr, 3.0, 1) for arr in data],
+            lambda col: f.array_remove(col, literal(3.0)),
+            lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data],
         ],
         [
-            f.list_remove(col, literal(3.0)),
-            lambda: [py_arr_remove(arr, 3.0, 1) for arr in data],
+            lambda col: f.list_remove(col, literal(3.0)),
+            lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data],
         ],
         [
-            f.array_remove_n(col, literal(3.0), literal(2)),
-            lambda: [py_arr_remove(arr, 3.0, 2) for arr in data],
+            lambda col: f.array_remove_n(col, literal(3.0), literal(2)),
+            lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data],
         ],
         [
-            f.list_remove_n(col, literal(3.0), literal(2)),
-            lambda: [py_arr_remove(arr, 3.0, 2) for arr in data],
+            lambda col: f.list_remove_n(col, literal(3.0), literal(2)),
+            lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data],
         ],
         [
-            f.array_remove_all(col, literal(3.0)),
-            lambda: [py_arr_remove(arr, 3.0) for arr in data],
+            lambda col: f.array_remove_all(col, literal(3.0)),
+            lambda data: [py_arr_remove(arr, 3.0) for arr in data],
         ],
         [
-            f.list_remove_all(col, literal(3.0)),
-            lambda: [py_arr_remove(arr, 3.0) for arr in data],
+            lambda col: f.list_remove_all(col, literal(3.0)),
+            lambda data: [py_arr_remove(arr, 3.0) for arr in data],
         ],
         [
-            f.array_repeat(col, literal(2)),
-            lambda: [[arr] * 2 for arr in data],
+            lambda col: f.array_repeat(col, literal(2)),
+            lambda data: [[arr] * 2 for arr in data],
         ],
         [
-            f.array_replace(col, literal(3.0), literal(4.0)),
-            lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
+            lambda col: f.array_replace(col, literal(3.0), literal(4.0)),
+            lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
         ],
         [
-            f.list_replace(col, literal(3.0), literal(4.0)),
-            lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
+            lambda col: f.list_replace(col, literal(3.0), literal(4.0)),
+            lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
         ],
         [
-            f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)),
-            lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
+            lambda col: f.array_replace_n(col, literal(3.0), literal(4.0), 
literal(1)),
+            lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
         ],
         [
-            f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)),
-            lambda: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data],
+            lambda col: f.list_replace_n(col, literal(3.0), literal(4.0), 
literal(2)),
+            lambda data: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data],
         ],
         [
-            f.array_replace_all(col, literal(3.0), literal(4.0)),
-            lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
+            lambda col: f.array_replace_all(col, literal(3.0), literal(4.0)),
+            lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
         ],
         [
-            f.list_replace_all(col, literal(3.0), literal(4.0)),
-            lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
+            lambda col: f.list_replace_all(col, literal(3.0), literal(4.0)),
+            lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
         ],
         [
-            f.array_slice(col, literal(2), literal(4)),
-            lambda: [arr[1:4] for arr in data],
+            lambda col: f.array_slice(col, literal(2), literal(4)),
+            lambda data: [arr[1:4] for arr in data],
         ],
-        # [
-        #     f.list_slice(col, literal(-1), literal(2)),
-        #     lambda: [arr[-1:2] for arr in data],
-        # ],
+        pytest.param(
+            lambda col: f.list_slice(col, literal(-1), literal(2)),
+            lambda data: [arr[-1:2] for arr in data],
+            marks=pytest.mark.xfail,
+        ),
         [
-            f.array_intersect(col, literal([3.0, 4.0])),
-            lambda: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
+            lambda col: f.array_intersect(col, literal([3.0, 4.0])),
+            lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
         ],
         [
-            f.list_intersect(col, literal([3.0, 4.0])),
-            lambda: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
+            lambda col: f.list_intersect(col, literal([3.0, 4.0])),
+            lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
         ],
         [
-            f.array_union(col, literal([12.0, 999.0])),
-            lambda: [np.union1d(arr, [12.0, 999.0]) for arr in data],
+            lambda col: f.array_union(col, literal([12.0, 999.0])),
+            lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data],
         ],
         [
-            f.list_union(col, literal([12.0, 999.0])),
-            lambda: [np.union1d(arr, [12.0, 999.0]) for arr in data],
+            lambda col: f.list_union(col, literal([12.0, 999.0])),
+            lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data],
         ],
         [
-            f.array_except(col, literal([3.0])),
-            lambda: [np.setdiff1d(arr, [3.0]) for arr in data],
+            lambda col: f.array_except(col, literal([3.0])),
+            lambda data: [np.setdiff1d(arr, [3.0]) for arr in data],
         ],
         [
-            f.list_except(col, literal([3.0])),
-            lambda: [np.setdiff1d(arr, [3.0]) for arr in data],
+            lambda col: f.list_except(col, literal([3.0])),
+            lambda data: [np.setdiff1d(arr, [3.0]) for arr in data],
         ],
         [
-            f.array_resize(col, literal(10), literal(0.0)),
-            lambda: [py_arr_resize(arr, 10, 0.0) for arr in data],
+            lambda col: f.array_resize(col, literal(10), literal(0.0)),
+            lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data],
         ],
         [
-            f.list_resize(col, literal(10), literal(0.0)),
-            lambda: [py_arr_resize(arr, 10, 0.0) for arr in data],
+            lambda col: f.list_resize(col, literal(10), literal(0.0)),
+            lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data],
         ],
-        [f.flatten(literal(data)), lambda: [py_flatten(data)]],
         [
-            f.range(literal(1), literal(5), literal(2)),
-            lambda: [np.arange(1, 5, 2)],
+            lambda col: f.range(literal(1), literal(5), literal(2)),
+            lambda data: [np.arange(1, 5, 2)],
         ],
-    ]
+    ],
+)
+def test_array_functions(stmt, py_expr):
+    data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
+    ctx = SessionContext()
+    batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], 
names=["arr"])
+    df = ctx.create_dataframe([[batch]])
 
-    for stmt, py_expr in test_items:
-        query_result = df.select(stmt).collect()[0].column(0)
-        for a, b in zip(query_result, py_expr()):
-            np.testing.assert_array_almost_equal(
-                np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
-            )
+    col = column("arr")
+    query_result = df.select(stmt(col)).collect()[0].column(0)
+    for a, b in zip(query_result, py_expr(data)):
+        np.testing.assert_array_almost_equal(
+            np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
+        )
 
-    obj_test_items = [
+
+def test_array_function_flatten():
+    data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
+    ctx = SessionContext()
+    batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], 
names=["arr"])
+    df = ctx.create_dataframe([[batch]])
+
+    stmt = f.flatten(literal(data))
+    py_expr = [py_flatten(data)]
+    query_result = df.select(stmt).collect()[0].column(0)
+    for a, b in zip(query_result, py_expr):
+        np.testing.assert_array_almost_equal(
+            np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
+        )
+
+
+@pytest.mark.parametrize(
+    ("stmt", "py_expr"),
+    [
         [
-            f.array_to_string(col, literal(",")),
-            lambda: [",".join([str(int(v)) for v in r]) for r in data],
+            f.array_to_string(column("arr"), literal(",")),
+            lambda data: [",".join([str(int(v)) for v in r]) for r in data],
         ],
         [
-            f.array_join(col, literal(",")),
-            lambda: [",".join([str(int(v)) for v in r]) for r in data],
+            f.array_join(column("arr"), literal(",")),
+            lambda data: [",".join([str(int(v)) for v in r]) for r in data],
         ],
         [
-            f.list_to_string(col, literal(",")),
-            lambda: [",".join([str(int(v)) for v in r]) for r in data],
+            f.list_to_string(column("arr"), literal(",")),
+            lambda data: [",".join([str(int(v)) for v in r]) for r in data],
         ],
         [
-            f.list_join(col, literal(",")),
-            lambda: [",".join([str(int(v)) for v in r]) for r in data],
+            f.list_join(column("arr"), literal(",")),
+            lambda data: [",".join([str(int(v)) for v in r]) for r in data],
         ],
-    ]
-
-    for stmt, py_expr in obj_test_items:
-        query_result = np.array(df.select(stmt).collect()[0].column(0))
-        for a, b in zip(query_result, py_expr()):
-            assert a == b
+    ],
+)
+def test_array_function_obj_tests(stmt, py_expr):
+    data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
+    ctx = SessionContext()
+    batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], 
names=["arr"])
+    df = ctx.create_dataframe([[batch]])
+    query_result = np.array(df.select(stmt).collect()[0].column(0))
+    for a, b in zip(query_result, py_expr(data)):
+        assert a == b
 
 
 def test_string_functions(df):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to