Repository: arrow
Updated Branches:
  refs/heads/master 5aea3a3d9 -> b4e9ba1ae
ARROW-968: [Python] Support slices in RecordBatch.__getitem__ Author: Wes McKinney <[email protected]> Closes #908 from wesm/ARROW-968 and squashes the following commits: 47b71a5d [Wes McKinney] Support slices in RecordBatch.__getitem__ Project: http://git-wip-us.apache.org/repos/asf/arrow/repo Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/b4e9ba1a Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/b4e9ba1a Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/b4e9ba1a Branch: refs/heads/master Commit: b4e9ba1ae68bcc449e4426b7c08d2984ed20c6be Parents: 5aea3a3 Author: Wes McKinney <[email protected]> Authored: Sat Jul 29 11:00:58 2017 -0400 Committer: Wes McKinney <[email protected]> Committed: Sat Jul 29 11:00:58 2017 -0400 ---------------------------------------------------------------------- python/pyarrow/array.pxi | 34 ++++++++++++++++++--------------- python/pyarrow/table.pxi | 9 +++++++-- python/pyarrow/tests/test_table.py | 11 +++++++++-- 3 files changed, 35 insertions(+), 19 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/arrow/blob/b4e9ba1a/python/pyarrow/array.pxi ---------------------------------------------------------------------- diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index efbe36f..67418aa 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -89,6 +89,23 @@ def array(object sequence, DataType type=None, MemoryPool memory_pool=None, return pyarrow_wrap_array(sp_array) +def _normalize_slice(object arrow_obj, slice key): + cdef Py_ssize_t n = len(arrow_obj) + + start = key.start or 0 + while start < 0: + start += n + + stop = key.stop if key.stop is not None else n + while stop < 0: + stop += n + + step = key.step or 1 + if step != 1: + raise IndexError('only slices with step 1 supported') + else: + return arrow_obj.slice(start, stop - start) + cdef class Array: @@ -230,23 +247,10 @@ cdef class Array: raise 
NotImplemented def __getitem__(self, key): - cdef: - Py_ssize_t n = len(self) + cdef Py_ssize_t n = len(self) if PySlice_Check(key): - start = key.start or 0 - while start < 0: - start += n - - stop = key.stop if key.stop is not None else n - while stop < 0: - stop += n - - step = key.step or 1 - if step != 1: - raise IndexError('only slices with step 1 supported') - else: - return self.slice(start, stop - start) + return _normalize_slice(self, key) while key < 0: key += len(self) http://git-wip-us.apache.org/repos/asf/arrow/blob/b4e9ba1a/python/pyarrow/table.pxi ---------------------------------------------------------------------- diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 6188e90..a9cb064 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -475,8 +475,13 @@ cdef class RecordBatch: ) return pyarrow_wrap_array(self.batch.column(i)) - def __getitem__(self, i): - return self.column(i) + def __getitem__(self, key): + cdef: + Py_ssize_t start, stop + if isinstance(key, slice): + return _normalize_slice(self, key) + else: + return self.column(key) def slice(self, offset=0, length=None): """ http://git-wip-us.apache.org/repos/asf/arrow/blob/b4e9ba1a/python/pyarrow/tests/test_table.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index c2aeda9..28b98f0 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -80,7 +80,7 @@ def test_recordbatch_basics(): batch[2] -def test_recordbatch_slice(): +def test_recordbatch_slice_getitem(): data = [ pa.array(range(5)), pa.array([-10, -5, 0, 5, 10]) @@ -90,7 +90,6 @@ def test_recordbatch_slice(): batch = pa.RecordBatch.from_arrays(data, names) sliced = batch.slice(2) - assert sliced.num_rows == 3 expected = pa.RecordBatch.from_arrays( @@ -111,6 +110,14 @@ def test_recordbatch_slice(): with pytest.raises(IndexError): batch.slice(-1) + # Check 
__getitem__-based slicing + assert batch.slice(0, 0).equals(batch[:0]) + assert batch.slice(0, 2).equals(batch[:2]) + assert batch.slice(2, 2).equals(batch[2:4]) + assert batch.slice(2, len(batch) - 2).equals(batch[2:]) + assert batch.slice(len(batch) - 2, 2).equals(batch[-2:]) + assert batch.slice(len(batch) - 4, 2).equals(batch[-4:-2]) + def test_recordbatch_from_to_pandas(): data = pd.DataFrame({
