Repository: arrow Updated Branches: refs/heads/master 30bb0d97d -> 4226adfbc
ARROW-515: [Python] Add read_all methods to FileReader, StreamReader Stacked on top of ARROW-514 Author: Wes McKinney <wes.mckin...@twosigma.com> Closes #307 from wesm/ARROW-515 and squashes the following commits: 6f2185c [Wes McKinney] Add read_all method to StreamReader, FileReader Project: http://git-wip-us.apache.org/repos/asf/arrow/repo Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/4226adfb Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/4226adfb Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/4226adfb Branch: refs/heads/master Commit: 4226adfbc6b3dff10b3fe7a6691b30bcc94140bd Parents: 30bb0d9 Author: Wes McKinney <wes.mckin...@twosigma.com> Authored: Fri Jan 27 10:46:34 2017 +0100 Committer: Uwe L. Korn <uw...@xhochy.com> Committed: Fri Jan 27 10:46:34 2017 +0100 ---------------------------------------------------------------------- python/pyarrow/io.pyx | 44 ++++++++++++++++++++++++++++++++++- python/pyarrow/table.pyx | 4 +--- python/pyarrow/tests/test_ipc.py | 19 +++++++++++++++ 3 files changed, 63 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/arrow/blob/4226adfb/python/pyarrow/io.pyx ---------------------------------------------------------------------- diff --git a/python/pyarrow/io.pyx b/python/pyarrow/io.pyx index e5f8b7a..8b56508 100644 --- a/python/pyarrow/io.pyx +++ b/python/pyarrow/io.pyx @@ -34,7 +34,8 @@ cimport pyarrow.includes.pyarrow as pyarrow from pyarrow.compat import frombytes, tobytes, encode_file_path from pyarrow.error cimport check_status from pyarrow.schema cimport Schema -from pyarrow.table cimport RecordBatch, batch_from_cbatch +from pyarrow.table cimport (RecordBatch, batch_from_cbatch, + table_from_ctable) cimport cpython as cp @@ -936,6 +937,27 @@ cdef class _StreamReader: return batch_from_cbatch(batch) + def read_all(self): + """ + Read all record batches as a pyarrow.Table + """ + cdef: + vector[shared_ptr[CRecordBatch]] batches + shared_ptr[CRecordBatch] batch + shared_ptr[CTable] table + c_string name = b'' + + with nogil: + while True: + check_status(self.reader.get().GetNextRecordBatch(&batch)) + if batch.get() == NULL: + break + batches.push_back(batch) + + check_status(CTable.FromRecordBatches(name, batches, &table)) + + return table_from_ctable(table) + cdef class _FileWriter(_StreamWriter): @@ -997,3 +1019,23 @@ cdef class _FileReader: # TODO(wesm): ARROW-503: Function was renamed. Remove after a period of # time has passed get_record_batch = get_batch + + def read_all(self): + """ + Read all record batches as a pyarrow.Table + """ + cdef: + vector[shared_ptr[CRecordBatch]] batches + shared_ptr[CTable] table + c_string name = b'' + int i, nbatches + + nbatches = self.num_record_batches + + batches.resize(nbatches) + with nogil: + for i in range(nbatches): + check_status(self.reader.get().GetRecordBatch(i, &batches[i])) + check_status(CTable.FromRecordBatches(name, batches, &table)) + + return table_from_ctable(table) http://git-wip-us.apache.org/repos/asf/arrow/blob/4226adfb/python/pyarrow/table.pyx ---------------------------------------------------------------------- diff --git a/python/pyarrow/table.pyx b/python/pyarrow/table.pyx index 9242330..1707210 100644 --- a/python/pyarrow/table.pyx +++ b/python/pyarrow/table.pyx @@ -690,9 +690,7 @@ cdef class Table: with nogil: check_status(CTable.FromRecordBatches(c_name, c_batches, &c_table)) - table = Table() - table.init(c_table) - return table + return table_from_ctable(c_table) def to_pandas(self, nthreads=None): """ http://git-wip-us.apache.org/repos/asf/arrow/blob/4226adfb/python/pyarrow/tests/test_ipc.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index 8ca464f..665a63b 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -83,6 +83,16 @@ class TestFile(MessagingTest, unittest.TestCase): batch = reader.get_batch(i) assert batches[i].equals(batch) + def test_read_all(self): + batches = self.write_batches() + file_contents = self._get_source() + + reader = pa.FileReader(file_contents) + + result = reader.read_all() + expected = pa.Table.from_batches(batches) + assert result.equals(expected) + class TestStream(MessagingTest, unittest.TestCase): @@ -104,6 +114,15 @@ class TestStream(MessagingTest, unittest.TestCase): with pytest.raises(StopIteration): reader.get_next_batch() + def test_read_all(self): + batches = self.write_batches() + file_contents = self._get_source() + reader = pa.StreamReader(file_contents) + + result = reader.read_all() + expected = pa.Table.from_batches(batches) + assert result.equals(expected) + class TestInMemoryFile(TestFile):