This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new e029559 test: parametrize test_array_functions (#678)
e029559 is described below
commit e029559f314c71d1050df985e0418307b759af4a
Author: Michael J Ward <[email protected]>
AuthorDate: Thu May 23 17:08:53 2024 -0500
test: parametrize test_array_functions (#678)
test_array_functions now has 56 passing test cases and 1 expected failure
(the negative-index `list_slice` case is the expected failure; ref #670).
test_array_function_flatten was broken out as a single test because it was
an outlier in terms of test-input.
test_array_function_obj_tests had a different set of asserts, so it was broken
out as its own parametrized test with 4 test cases.
Ref #671
---
python/datafusion/tests/test_functions.py | 414 ++++++++++++++++--------------
1 file changed, 223 insertions(+), 191 deletions(-)
diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py
index d34e46b..493b6b6 100644
--- a/python/datafusion/tests/test_functions.py
+++ b/python/datafusion/tests/test_functions.py
@@ -209,324 +209,356 @@ def test_math_functions():
)
-def test_array_functions():
- data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
- ctx = SessionContext()
- batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)],
names=["arr"])
- df = ctx.create_dataframe([[batch]])
+def py_indexof(arr, v):
+ try:
+ return arr.index(v) + 1
+ except ValueError:
+ return np.nan
+
- def py_indexof(arr, v):
+def py_arr_remove(arr, v, n=None):
+ new_arr = arr[:]
+ found = 0
+ while found != n:
try:
- return arr.index(v) + 1
+ new_arr.remove(v)
+ found += 1
except ValueError:
- return np.nan
-
- def py_arr_remove(arr, v, n=None):
- new_arr = arr[:]
- found = 0
- while found != n:
- try:
- new_arr.remove(v)
- found += 1
- except ValueError:
- break
-
- return new_arr
-
- def py_arr_replace(arr, from_, to, n=None):
- new_arr = arr[:]
- found = 0
- while found != n:
- try:
- idx = new_arr.index(from_)
- new_arr[idx] = to
- found += 1
- except ValueError:
- break
-
- return new_arr
-
- def py_arr_resize(arr, size, value):
- arr = np.asarray(arr)
- return np.pad(
- arr,
- [(0, size - arr.shape[0])],
- "constant",
- constant_values=value,
- )
+ break
- def py_flatten(arr):
- result = []
- for elem in arr:
- if isinstance(elem, list):
- result.extend(py_flatten(elem))
- else:
- result.append(elem)
- return result
+ return new_arr
- col = column("arr")
- test_items = [
+
+def py_arr_replace(arr, from_, to, n=None):
+ new_arr = arr[:]
+ found = 0
+ while found != n:
+ try:
+ idx = new_arr.index(from_)
+ new_arr[idx] = to
+ found += 1
+ except ValueError:
+ break
+
+ return new_arr
+
+
+def py_arr_resize(arr, size, value):
+ arr = np.asarray(arr)
+ return np.pad(
+ arr,
+ [(0, size - arr.shape[0])],
+ "constant",
+ constant_values=value,
+ )
+
+
+def py_flatten(arr):
+ result = []
+ for elem in arr:
+ if isinstance(elem, list):
+ result.extend(py_flatten(elem))
+ else:
+ result.append(elem)
+ return result
+
+
[email protected](
+ ("stmt", "py_expr"),
+ [
[
- f.array_append(col, literal(99.0)),
- lambda: [np.append(arr, 99.0) for arr in data],
+ lambda col: f.array_append(col, literal(99.0)),
+ lambda data: [np.append(arr, 99.0) for arr in data],
],
[
- f.array_push_back(col, literal(99.0)),
- lambda: [np.append(arr, 99.0) for arr in data],
+ lambda col: f.array_push_back(col, literal(99.0)),
+ lambda data: [np.append(arr, 99.0) for arr in data],
],
[
- f.list_append(col, literal(99.0)),
- lambda: [np.append(arr, 99.0) for arr in data],
+ lambda col: f.list_append(col, literal(99.0)),
+ lambda data: [np.append(arr, 99.0) for arr in data],
],
[
- f.list_push_back(col, literal(99.0)),
- lambda: [np.append(arr, 99.0) for arr in data],
+ lambda col: f.list_push_back(col, literal(99.0)),
+ lambda data: [np.append(arr, 99.0) for arr in data],
],
[
- f.array_concat(col, col),
- lambda: [np.concatenate([arr, arr]) for arr in data],
+ lambda col: f.array_concat(col, col),
+ lambda data: [np.concatenate([arr, arr]) for arr in data],
],
[
- f.array_cat(col, col),
- lambda: [np.concatenate([arr, arr]) for arr in data],
+ lambda col: f.array_cat(col, col),
+ lambda data: [np.concatenate([arr, arr]) for arr in data],
],
[
- f.array_dims(col),
- lambda: [[len(r)] for r in data],
+ lambda col: f.array_dims(col),
+ lambda data: [[len(r)] for r in data],
],
[
- f.array_distinct(col),
- lambda: [list(set(r)) for r in data],
+ lambda col: f.array_distinct(col),
+ lambda data: [list(set(r)) for r in data],
],
[
- f.list_distinct(col),
- lambda: [list(set(r)) for r in data],
+ lambda col: f.list_distinct(col),
+ lambda data: [list(set(r)) for r in data],
],
[
- f.list_dims(col),
- lambda: [[len(r)] for r in data],
+ lambda col: f.list_dims(col),
+ lambda data: [[len(r)] for r in data],
],
[
- f.array_element(col, literal(1)),
- lambda: [r[0] for r in data],
+ lambda col: f.array_element(col, literal(1)),
+ lambda data: [r[0] for r in data],
],
[
- f.array_extract(col, literal(1)),
- lambda: [r[0] for r in data],
+ lambda col: f.array_extract(col, literal(1)),
+ lambda data: [r[0] for r in data],
],
[
- f.list_element(col, literal(1)),
- lambda: [r[0] for r in data],
+ lambda col: f.list_element(col, literal(1)),
+ lambda data: [r[0] for r in data],
],
[
- f.list_extract(col, literal(1)),
- lambda: [r[0] for r in data],
+ lambda col: f.list_extract(col, literal(1)),
+ lambda data: [r[0] for r in data],
],
[
- f.array_length(col),
- lambda: [len(r) for r in data],
+ lambda col: f.array_length(col),
+ lambda data: [len(r) for r in data],
],
[
- f.list_length(col),
- lambda: [len(r) for r in data],
+ lambda col: f.list_length(col),
+ lambda data: [len(r) for r in data],
],
[
- f.array_has(col, literal(1.0)),
- lambda: [1.0 in r for r in data],
+ lambda col: f.array_has(col, literal(1.0)),
+ lambda data: [1.0 in r for r in data],
],
[
- f.array_has_all(col, f.make_array(*[literal(v) for v in [1.0, 3.0,
5.0]])),
- lambda: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
+ lambda col: f.array_has_all(
+ col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
+ ),
+ lambda data: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in
data],
],
[
- f.array_has_any(col, f.make_array(*[literal(v) for v in [1.0, 3.0,
5.0]])),
- lambda: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
+ lambda col: f.array_has_any(
+ col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
+ ),
+ lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in
data],
],
[
- f.array_position(col, literal(1.0)),
- lambda: [py_indexof(r, 1.0) for r in data],
+ lambda col: f.array_position(col, literal(1.0)),
+ lambda data: [py_indexof(r, 1.0) for r in data],
],
[
- f.array_indexof(col, literal(1.0)),
- lambda: [py_indexof(r, 1.0) for r in data],
+ lambda col: f.array_indexof(col, literal(1.0)),
+ lambda data: [py_indexof(r, 1.0) for r in data],
],
[
- f.list_position(col, literal(1.0)),
- lambda: [py_indexof(r, 1.0) for r in data],
+ lambda col: f.list_position(col, literal(1.0)),
+ lambda data: [py_indexof(r, 1.0) for r in data],
],
[
- f.list_indexof(col, literal(1.0)),
- lambda: [py_indexof(r, 1.0) for r in data],
+ lambda col: f.list_indexof(col, literal(1.0)),
+ lambda data: [py_indexof(r, 1.0) for r in data],
],
[
- f.array_positions(col, literal(1.0)),
- lambda: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in
data],
+ lambda col: f.array_positions(col, literal(1.0)),
+ lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r
in data],
],
[
- f.list_positions(col, literal(1.0)),
- lambda: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in
data],
+ lambda col: f.list_positions(col, literal(1.0)),
+ lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r
in data],
],
[
- f.array_ndims(col),
- lambda: [np.array(r).ndim for r in data],
+ lambda col: f.array_ndims(col),
+ lambda data: [np.array(r).ndim for r in data],
],
[
- f.list_ndims(col),
- lambda: [np.array(r).ndim for r in data],
+ lambda col: f.list_ndims(col),
+ lambda data: [np.array(r).ndim for r in data],
],
[
- f.array_prepend(literal(99.0), col),
- lambda: [np.insert(arr, 0, 99.0) for arr in data],
+ lambda col: f.array_prepend(literal(99.0), col),
+ lambda data: [np.insert(arr, 0, 99.0) for arr in data],
],
[
- f.array_push_front(literal(99.0), col),
- lambda: [np.insert(arr, 0, 99.0) for arr in data],
+ lambda col: f.array_push_front(literal(99.0), col),
+ lambda data: [np.insert(arr, 0, 99.0) for arr in data],
],
[
- f.list_prepend(literal(99.0), col),
- lambda: [np.insert(arr, 0, 99.0) for arr in data],
+ lambda col: f.list_prepend(literal(99.0), col),
+ lambda data: [np.insert(arr, 0, 99.0) for arr in data],
],
[
- f.list_push_front(literal(99.0), col),
- lambda: [np.insert(arr, 0, 99.0) for arr in data],
+ lambda col: f.list_push_front(literal(99.0), col),
+ lambda data: [np.insert(arr, 0, 99.0) for arr in data],
],
[
- f.array_pop_back(col),
- lambda: [arr[:-1] for arr in data],
+ lambda col: f.array_pop_back(col),
+ lambda data: [arr[:-1] for arr in data],
],
[
- f.array_pop_front(col),
- lambda: [arr[1:] for arr in data],
+ lambda col: f.array_pop_front(col),
+ lambda data: [arr[1:] for arr in data],
],
[
- f.array_remove(col, literal(3.0)),
- lambda: [py_arr_remove(arr, 3.0, 1) for arr in data],
+ lambda col: f.array_remove(col, literal(3.0)),
+ lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data],
],
[
- f.list_remove(col, literal(3.0)),
- lambda: [py_arr_remove(arr, 3.0, 1) for arr in data],
+ lambda col: f.list_remove(col, literal(3.0)),
+ lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data],
],
[
- f.array_remove_n(col, literal(3.0), literal(2)),
- lambda: [py_arr_remove(arr, 3.0, 2) for arr in data],
+ lambda col: f.array_remove_n(col, literal(3.0), literal(2)),
+ lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data],
],
[
- f.list_remove_n(col, literal(3.0), literal(2)),
- lambda: [py_arr_remove(arr, 3.0, 2) for arr in data],
+ lambda col: f.list_remove_n(col, literal(3.0), literal(2)),
+ lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data],
],
[
- f.array_remove_all(col, literal(3.0)),
- lambda: [py_arr_remove(arr, 3.0) for arr in data],
+ lambda col: f.array_remove_all(col, literal(3.0)),
+ lambda data: [py_arr_remove(arr, 3.0) for arr in data],
],
[
- f.list_remove_all(col, literal(3.0)),
- lambda: [py_arr_remove(arr, 3.0) for arr in data],
+ lambda col: f.list_remove_all(col, literal(3.0)),
+ lambda data: [py_arr_remove(arr, 3.0) for arr in data],
],
[
- f.array_repeat(col, literal(2)),
- lambda: [[arr] * 2 for arr in data],
+ lambda col: f.array_repeat(col, literal(2)),
+ lambda data: [[arr] * 2 for arr in data],
],
[
- f.array_replace(col, literal(3.0), literal(4.0)),
- lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
+ lambda col: f.array_replace(col, literal(3.0), literal(4.0)),
+ lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
],
[
- f.list_replace(col, literal(3.0), literal(4.0)),
- lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
+ lambda col: f.list_replace(col, literal(3.0), literal(4.0)),
+ lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
],
[
- f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)),
- lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
+ lambda col: f.array_replace_n(col, literal(3.0), literal(4.0),
literal(1)),
+ lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
],
[
- f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)),
- lambda: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data],
+ lambda col: f.list_replace_n(col, literal(3.0), literal(4.0),
literal(2)),
+ lambda data: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data],
],
[
- f.array_replace_all(col, literal(3.0), literal(4.0)),
- lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
+ lambda col: f.array_replace_all(col, literal(3.0), literal(4.0)),
+ lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
],
[
- f.list_replace_all(col, literal(3.0), literal(4.0)),
- lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
+ lambda col: f.list_replace_all(col, literal(3.0), literal(4.0)),
+ lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
],
[
- f.array_slice(col, literal(2), literal(4)),
- lambda: [arr[1:4] for arr in data],
+ lambda col: f.array_slice(col, literal(2), literal(4)),
+ lambda data: [arr[1:4] for arr in data],
],
- # [
- # f.list_slice(col, literal(-1), literal(2)),
- # lambda: [arr[-1:2] for arr in data],
- # ],
+ pytest.param(
+ lambda col: f.list_slice(col, literal(-1), literal(2)),
+ lambda data: [arr[-1:2] for arr in data],
+ marks=pytest.mark.xfail,
+ ),
[
- f.array_intersect(col, literal([3.0, 4.0])),
- lambda: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
+ lambda col: f.array_intersect(col, literal([3.0, 4.0])),
+ lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
],
[
- f.list_intersect(col, literal([3.0, 4.0])),
- lambda: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
+ lambda col: f.list_intersect(col, literal([3.0, 4.0])),
+ lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
],
[
- f.array_union(col, literal([12.0, 999.0])),
- lambda: [np.union1d(arr, [12.0, 999.0]) for arr in data],
+ lambda col: f.array_union(col, literal([12.0, 999.0])),
+ lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data],
],
[
- f.list_union(col, literal([12.0, 999.0])),
- lambda: [np.union1d(arr, [12.0, 999.0]) for arr in data],
+ lambda col: f.list_union(col, literal([12.0, 999.0])),
+ lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data],
],
[
- f.array_except(col, literal([3.0])),
- lambda: [np.setdiff1d(arr, [3.0]) for arr in data],
+ lambda col: f.array_except(col, literal([3.0])),
+ lambda data: [np.setdiff1d(arr, [3.0]) for arr in data],
],
[
- f.list_except(col, literal([3.0])),
- lambda: [np.setdiff1d(arr, [3.0]) for arr in data],
+ lambda col: f.list_except(col, literal([3.0])),
+ lambda data: [np.setdiff1d(arr, [3.0]) for arr in data],
],
[
- f.array_resize(col, literal(10), literal(0.0)),
- lambda: [py_arr_resize(arr, 10, 0.0) for arr in data],
+ lambda col: f.array_resize(col, literal(10), literal(0.0)),
+ lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data],
],
[
- f.list_resize(col, literal(10), literal(0.0)),
- lambda: [py_arr_resize(arr, 10, 0.0) for arr in data],
+ lambda col: f.list_resize(col, literal(10), literal(0.0)),
+ lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data],
],
- [f.flatten(literal(data)), lambda: [py_flatten(data)]],
[
- f.range(literal(1), literal(5), literal(2)),
- lambda: [np.arange(1, 5, 2)],
+ lambda col: f.range(literal(1), literal(5), literal(2)),
+ lambda data: [np.arange(1, 5, 2)],
],
- ]
+ ],
+)
+def test_array_functions(stmt, py_expr):
+ data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
+ ctx = SessionContext()
+ batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)],
names=["arr"])
+ df = ctx.create_dataframe([[batch]])
- for stmt, py_expr in test_items:
- query_result = df.select(stmt).collect()[0].column(0)
- for a, b in zip(query_result, py_expr()):
- np.testing.assert_array_almost_equal(
- np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
- )
+ col = column("arr")
+ query_result = df.select(stmt(col)).collect()[0].column(0)
+ for a, b in zip(query_result, py_expr(data)):
+ np.testing.assert_array_almost_equal(
+ np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
+ )
- obj_test_items = [
+
+def test_array_function_flatten():
+ data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
+ ctx = SessionContext()
+ batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)],
names=["arr"])
+ df = ctx.create_dataframe([[batch]])
+
+ stmt = f.flatten(literal(data))
+ py_expr = [py_flatten(data)]
+ query_result = df.select(stmt).collect()[0].column(0)
+ for a, b in zip(query_result, py_expr):
+ np.testing.assert_array_almost_equal(
+ np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
+ )
+
+
[email protected](
+ ("stmt", "py_expr"),
+ [
[
- f.array_to_string(col, literal(",")),
- lambda: [",".join([str(int(v)) for v in r]) for r in data],
+ f.array_to_string(column("arr"), literal(",")),
+ lambda data: [",".join([str(int(v)) for v in r]) for r in data],
],
[
- f.array_join(col, literal(",")),
- lambda: [",".join([str(int(v)) for v in r]) for r in data],
+ f.array_join(column("arr"), literal(",")),
+ lambda data: [",".join([str(int(v)) for v in r]) for r in data],
],
[
- f.list_to_string(col, literal(",")),
- lambda: [",".join([str(int(v)) for v in r]) for r in data],
+ f.list_to_string(column("arr"), literal(",")),
+ lambda data: [",".join([str(int(v)) for v in r]) for r in data],
],
[
- f.list_join(col, literal(",")),
- lambda: [",".join([str(int(v)) for v in r]) for r in data],
+ f.list_join(column("arr"), literal(",")),
+ lambda data: [",".join([str(int(v)) for v in r]) for r in data],
],
- ]
-
- for stmt, py_expr in obj_test_items:
- query_result = np.array(df.select(stmt).collect()[0].column(0))
- for a, b in zip(query_result, py_expr()):
- assert a == b
+ ],
+)
+def test_array_function_obj_tests(stmt, py_expr):
+ data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
+ ctx = SessionContext()
+ batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)],
names=["arr"])
+ df = ctx.create_dataframe([[batch]])
+ query_result = np.array(df.select(stmt).collect()[0].column(0))
+ for a, b in zip(query_result, py_expr(data)):
+ assert a == b
def test_string_functions(df):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]