This is an automated email from the ASF dual-hosted git repository.
jorisvandenbossche pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 7d834d65c3 GH-36709: [Python] Allow to specify use_threads=False in Table.group_by to have stable ordering (#36768)
7d834d65c3 is described below
commit 7d834d65c37c17d1c19bfb497eadb983893c9ea0
Author: Joris Van den Bossche <[email protected]>
AuthorDate: Thu Oct 5 09:21:56 2023 +0200
GH-36709: [Python] Allow to specify use_threads=False in Table.group_by to have stable ordering (#36768)
### Rationale for this change
Add a `use_threads` keyword to the `group_by` method on `Table`, and pass it through to the `Declaration.to_table` call. This makes it possible to specify `use_threads=False` to get a stable ordering of the output, which is also required for certain aggregations (e.g. `"first"` fails with the default `use_threads=True`).
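A minimal sketch of the new keyword in use (the table and aggregation here are illustrative, not taken from the patch):

```python
import pyarrow as pa

table = pa.table({"key": ["a", "a", "b"], "value": [1, 2, 3]})

# With the default use_threads=True the row order of the grouped result is
# not guaranteed; use_threads=False gives a stable order and also enables
# order-dependent aggregations such as "first".
result = table.group_by("key", use_threads=False).aggregate([("value", "first")])
print(result)  # columns: key, value_first
```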
### Are these changes tested?
Yes, added a test (similar to the existing one for `filter`) that would fail more than 50% of the time if the output were no longer ordered.
* Closes: #36709
Authored-by: Joris Van den Bossche <[email protected]>
Signed-off-by: Joris Van den Bossche <[email protected]>
---
 python/pyarrow/acero.py                |  4 ++--
 python/pyarrow/table.pxi               | 20 +++++++++++++++-----
 python/pyarrow/tests/test_exec_plan.py | 14 ++++++++++++++
 python/pyarrow/tests/test_table.py     | 15 +++++++++++++++
 4 files changed, 46 insertions(+), 7 deletions(-)
diff --git a/python/pyarrow/acero.py b/python/pyarrow/acero.py
index 63da0a3786..0609e45753 100644
--- a/python/pyarrow/acero.py
+++ b/python/pyarrow/acero.py
@@ -299,10 +299,10 @@ def _sort_source(table_or_dataset, sort_keys, output_type=Table, **kwargs):
         raise TypeError("Unsupported output type")


-def _group_by(table, aggregates, keys):
+def _group_by(table, aggregates, keys, use_threads=True):
     decl = Declaration.from_sequence([
         Declaration("table_source", TableSourceNodeOptions(table)),
         Declaration("aggregate", AggregateNodeOptions(aggregates, keys=keys))
     ])
-    return decl.to_table(use_threads=True)
+    return decl.to_table(use_threads=use_threads)
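For reference, a standalone sketch of the pipeline `_group_by` builds, runnable against `pyarrow.acero` (the sample table and the `hash_sum` aggregate are illustrative assumptions, not part of the patch):

```python
import pyarrow as pa
from pyarrow.acero import (
    AggregateNodeOptions, Declaration, TableSourceNodeOptions,
)

table = pa.table({"key": ["a", "b", "a"], "value": [1, 2, 3]})

# Each aggregate is (target, function, options, output_name); grouped
# aggregations use the "hash_" function variants.
decl = Declaration.from_sequence([
    Declaration("table_source", TableSourceNodeOptions(table)),
    Declaration("aggregate", AggregateNodeOptions(
        [("value", "hash_sum", None, "value_sum")], keys=["key"])),
])

# use_threads=False executes the plan serially, so the output order is stable.
result = decl.to_table(use_threads=False)
```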
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 2eae38485d..36601130b3 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -4599,8 +4599,9 @@ cdef class Table(_Tabular):
"""
return self.drop_columns(columns)
- def group_by(self, keys):
- """Declare a grouping over the columns of the table.
+ def group_by(self, keys, use_threads=True):
+ """
+ Declare a grouping over the columns of the table.
Resulting grouping can then be used to perform aggregations
with a subsequent ``aggregate()`` method.
@@ -4609,6 +4610,9 @@ cdef class Table(_Tabular):
         ----------
         keys : str or list[str]
             Name of the columns that should be used as the grouping key.
+        use_threads : bool, default True
+            Whether to use multithreading or not. When set to True (the
+            default), no stable ordering of the output is guaranteed.

         Returns
         -------
@@ -4635,7 +4639,7 @@ cdef class Table(_Tabular):
         year: [[2020,2022,2021,2019]]
         n_legs_sum: [[2,6,104,5]]
         """
-        return TableGroupBy(self, keys)
+        return TableGroupBy(self, keys, use_threads=use_threads)

     def join(self, right_table, keys, right_keys=None, join_type="left outer",
              left_suffix=None, right_suffix=None, coalesce_keys=True,
@@ -5183,6 +5187,9 @@ class TableGroupBy:
        Input table to execute the aggregation on.
    keys : str or list[str]
        Name of the grouped columns.
+    use_threads : bool, default True
+        Whether to use multithreading or not. When set to True (the default),
+        no stable ordering of the output is guaranteed.

    Examples
    --------
@@ -5208,12 +5215,13 @@ class TableGroupBy:
     values_sum: [[3,7,5]]
     """

-    def __init__(self, table, keys):
+    def __init__(self, table, keys, use_threads=True):
         if isinstance(keys, str):
             keys = [keys]

         self._table = table
         self.keys = keys
+        self._use_threads = use_threads

     def aggregate(self, aggregations):
         """
@@ -5328,4 +5336,6 @@ list[tuple(str, str, FunctionOptions)]
             aggr_name = "_".join(target) + "_" + func_nohash
             group_by_aggrs.append((target, func, opt, aggr_name))

-        return _pac()._group_by(self._table, group_by_aggrs, self.keys)
+        return _pac()._group_by(
+            self._table, group_by_aggrs, self.keys,
+            use_threads=self._use_threads
+        )
diff --git a/python/pyarrow/tests/test_exec_plan.py b/python/pyarrow/tests/test_exec_plan.py
index 58c618179b..d85a2c2152 100644
--- a/python/pyarrow/tests/test_exec_plan.py
+++ b/python/pyarrow/tests/test_exec_plan.py
@@ -321,3 +321,17 @@ def test_join_extension_array_column():
     result = _perform_join(
         "left outer", t1, ["colB"], t3, ["colC"])
     assert result["colB"] == pa.chunked_array(ext_array)
+
+
+def test_group_by_ordering():
+    # GH-36709 - preserve ordering in groupby by setting use_threads=False
+    table1 = pa.table({'a': [1, 2, 3, 4], 'b': ['a'] * 4})
+    table2 = pa.table({'a': [1, 2, 3, 4], 'b': ['b'] * 4})
+    table = pa.concat_tables([table1, table2])
+
+    for _ in range(50):
+        # 50 seems to consistently cause errors when order is not preserved.
+        # If the order problem is reintroduced this test will become flaky
+        # which is still a signal that the order is not preserved.
+        result = table.group_by("b", use_threads=False).aggregate([])
+        assert result["b"] == pa.chunked_array([["a"], ["b"]])
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index f93c6bbc2c..b9e0d69219 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -2175,6 +2175,21 @@ def test_table_group_by():
     }


[email protected]
+def test_table_group_by_first():
+    # "first" is an ordered aggregation -> requires to specify use_threads=False
+    table1 = pa.table({'a': [1, 2, 3, 4], 'b': ['a', 'b'] * 2})
+    table2 = pa.table({'a': [1, 2, 3, 4], 'b': ['b', 'a'] * 2})
+    table = pa.concat_tables([table1, table2])
+
+    with pytest.raises(NotImplementedError):
+        table.group_by("b").aggregate([("a", "first")])
+
+    result = table.group_by("b", use_threads=False).aggregate([("a", "first")])
+    expected = pa.table({"b": ["a", "b"], "a_first": [1, 2]})
+    assert result.equals(expected)
+
+
 def test_table_to_recordbatchreader():
     table = pa.Table.from_pydict({'x': [1, 2, 3]})
     reader = table.to_reader()