Repository: arrow Updated Branches: refs/heads/master 099f61ce5 -> bb0a75885
http://git-wip-us.apache.org/repos/asf/arrow/blob/bb0a7588/python/pyarrow/ipc.pxi ---------------------------------------------------------------------- diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi new file mode 100644 index 0000000..d6df30b --- /dev/null +++ b/python/pyarrow/ipc.pxi @@ -0,0 +1,480 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +cdef class Message: + cdef: + unique_ptr[CMessage] message + + def __cinit__(self): + pass + + def __null_check(self): + if self.message.get() == NULL: + raise TypeError('Message improperly initialized (null)') + + property type: + + def __get__(self): + self.__null_check() + return frombytes(FormatMessageType(self.message.get().type())) + + property metadata: + + def __get__(self): + self.__null_check() + return pyarrow_wrap_buffer(self.message.get().metadata()) + + property body: + + def __get__(self): + self.__null_check() + cdef shared_ptr[CBuffer] body = self.message.get().body() + if body.get() == NULL: + return None + else: + return pyarrow_wrap_buffer(body) + + def equals(self, Message other): + """ + Returns True if the message contents (metadata and body) are identical + + Parameters + ---------- + other : Message + + Returns + ------- + are_equal : bool + """ + cdef c_bool result + with nogil: + result = self.message.get().Equals(deref(other.message.get())) + return result + + def serialize(self, memory_pool=None): + """ + Write message to Buffer with length-prefixed metadata, then body + + Parameters + ---------- + memory_pool : MemoryPool, default None + Uses default memory pool if not specified + + Returns + ------- + serialized : Buffer + """ + cdef: + BufferOutputStream stream = BufferOutputStream(memory_pool) + int64_t output_length = 0 + + with nogil: + check_status(self.message.get() + .SerializeTo(stream.wr_file.get(), + &output_length)) + return stream.get_result() + + def __repr__(self): + metadata_len = self.metadata.size + body = self.body + body_len = 0 if body is None else body.size + + return """pyarrow.Message +type: {0} +metadata length: {1} +body length: {2}""".format(self.type, metadata_len, body_len) + + +cdef class MessageReader: + cdef: + unique_ptr[CMessageReader] reader + + def __cinit__(self): + pass + + def __null_check(self): + if self.reader.get() == NULL: + raise TypeError('Message improperly initialized (null)') + + def __repr__(self): + self.__null_check() + return object.__repr__(self) + + @staticmethod + def open_stream(source): + cdef MessageReader result = MessageReader() + cdef shared_ptr[InputStream] in_stream + get_input_stream(source, &in_stream) + with nogil: + result.reader.reset(new CInputStreamMessageReader(in_stream)) + + return result + + def __iter__(self): + while True: + yield self.read_next_message() + + def read_next_message(self): + """ + Read next Message from the stream. Raises StopIteration at end of + stream + """ + cdef Message result = Message() + + with nogil: + check_status(self.reader.get().ReadNextMessage(&result.message)) + + if result.message.get() == NULL: + raise StopIteration + + return result + +# ---------------------------------------------------------------------- +# File and stream readers and writers + +cdef class _RecordBatchWriter: + cdef: + shared_ptr[CRecordBatchWriter] writer + shared_ptr[OutputStream] sink + bint closed + + def __cinit__(self): + self.closed = True + + def __dealloc__(self): + if not self.closed: + self.close() + + def _open(self, sink, Schema schema): + cdef: + shared_ptr[CRecordBatchStreamWriter] writer + + get_writer(sink, &self.sink) + + with nogil: + check_status( + CRecordBatchStreamWriter.Open(self.sink.get(), + schema.sp_schema, + &writer)) + + self.writer = <shared_ptr[CRecordBatchWriter]> writer + self.closed = False + + def write_batch(self, RecordBatch batch): + with nogil: + check_status(self.writer.get() + .WriteRecordBatch(deref(batch.batch))) + + def close(self): + with nogil: + check_status(self.writer.get().Close()) + self.closed = True + + +cdef get_input_stream(object source, shared_ptr[InputStream]* out): + cdef: + shared_ptr[RandomAccessFile] file_handle + + get_reader(source, &file_handle) + out[0] = <shared_ptr[InputStream]> file_handle + + +cdef class _RecordBatchReader: + cdef: + shared_ptr[CRecordBatchReader] reader + + cdef readonly: + Schema schema + + def __cinit__(self): + pass + + def _open(self, source): + cdef: + shared_ptr[InputStream] in_stream + shared_ptr[CRecordBatchStreamReader] reader + + get_input_stream(source, &in_stream) + + with nogil: + check_status(CRecordBatchStreamReader.Open(in_stream, &reader)) + + self.reader = <shared_ptr[CRecordBatchReader]> reader + self.schema = Schema() + self.schema.init_schema(self.reader.get().schema()) + + def __iter__(self): + while True: + yield self.read_next_batch() + + def get_next_batch(self): + import warnings + warnings.warn('Please use read_next_batch instead of ' + 'get_next_batch', FutureWarning) + return self.read_next_batch() + + def read_next_batch(self): + """ + Read next RecordBatch from the stream. Raises StopIteration at end of + stream + """ + cdef shared_ptr[CRecordBatch] batch + + with nogil: + check_status(self.reader.get().ReadNextRecordBatch(&batch)) + + if batch.get() == NULL: + raise StopIteration + + return pyarrow_wrap_batch(batch) + + def read_all(self): + """ + Read all record batches as a pyarrow.Table + """ + cdef: + vector[shared_ptr[CRecordBatch]] batches + shared_ptr[CRecordBatch] batch + shared_ptr[CTable] table + + with nogil: + while True: + check_status(self.reader.get().ReadNextRecordBatch(&batch)) + if batch.get() == NULL: + break + batches.push_back(batch) + + check_status(CTable.FromRecordBatches(batches, &table)) + + return pyarrow_wrap_table(table) + + +cdef class _RecordBatchFileWriter(_RecordBatchWriter): + + def _open(self, sink, Schema schema): + cdef shared_ptr[CRecordBatchFileWriter] writer + get_writer(sink, &self.sink) + + with nogil: + check_status( + CRecordBatchFileWriter.Open(self.sink.get(), schema.sp_schema, + &writer)) + + # Cast to base class, because has same interface + self.writer = <shared_ptr[CRecordBatchWriter]> writer + self.closed = False + + +cdef class _RecordBatchFileReader: + cdef: + shared_ptr[CRecordBatchFileReader] reader + + cdef readonly: + Schema schema + + def __cinit__(self): + pass + + def _open(self, source, footer_offset=None): + cdef shared_ptr[RandomAccessFile] reader + get_reader(source, &reader) + + cdef int64_t offset = 0 + if footer_offset is not None: + offset = footer_offset + + with nogil: + if offset != 0: + check_status(CRecordBatchFileReader.Open2( + reader, offset, &self.reader)) + else: + check_status(CRecordBatchFileReader.Open(reader, &self.reader)) + + self.schema = pyarrow_wrap_schema(self.reader.get().schema()) + + property num_record_batches: + + def __get__(self): + return self.reader.get().num_record_batches() + + def get_batch(self, int i): + cdef shared_ptr[CRecordBatch] batch + + if i < 0 or i >= self.num_record_batches: + raise ValueError('Batch number {0} out of range'.format(i)) + + with nogil: + check_status(self.reader.get().ReadRecordBatch(i, &batch)) + + return pyarrow_wrap_batch(batch) + + # TODO(wesm): ARROW-503: Function was renamed. Remove after a period of + # time has passed + get_record_batch = get_batch + + def read_all(self): + """ + Read all record batches as a pyarrow.Table + """ + cdef: + vector[shared_ptr[CRecordBatch]] batches + shared_ptr[CTable] table + int i, nbatches + + nbatches = self.num_record_batches + + batches.resize(nbatches) + with nogil: + for i in range(nbatches): + check_status(self.reader.get().ReadRecordBatch(i, &batches[i])) + check_status(CTable.FromRecordBatches(batches, &table)) + + return pyarrow_wrap_table(table) + + +def get_tensor_size(Tensor tensor): + """ + Return total size of serialized Tensor including metadata and padding + """ + cdef int64_t size + with nogil: + check_status(GetTensorSize(deref(tensor.tp), &size)) + return size + + +def get_record_batch_size(RecordBatch batch): + """ + Return total size of serialized RecordBatch including metadata and padding + """ + cdef int64_t size + with nogil: + check_status(GetRecordBatchSize(deref(batch.batch), &size)) + return size + + +def write_tensor(Tensor tensor, NativeFile dest): + """ + Write pyarrow.Tensor to pyarrow.NativeFile object its current position + + Parameters + ---------- + tensor : pyarrow.Tensor + dest : pyarrow.NativeFile + + Returns + ------- + bytes_written : int + Total number of bytes written to the file + """ + cdef: + int32_t metadata_length + int64_t body_length + + dest._assert_writeable() + + with nogil: + check_status( + WriteTensor(deref(tensor.tp), dest.wr_file.get(), + &metadata_length, &body_length)) + + return metadata_length + body_length + + +def read_tensor(NativeFile source): + """ + Read pyarrow.Tensor from pyarrow.NativeFile object from current + position. If the file source supports zero copy (e.g. a memory map), then + this operation does not allocate any memory + + Parameters + ---------- + source : pyarrow.NativeFile + + Returns + ------- + tensor : Tensor + """ + cdef: + shared_ptr[CTensor] sp_tensor + + source._assert_readable() + + cdef int64_t offset = source.tell() + with nogil: + check_status(ReadTensor(offset, source.rd_file.get(), &sp_tensor)) + + return pyarrow_wrap_tensor(sp_tensor) + + +def read_message(source): + """ + Read length-prefixed message from file or buffer-like object + + Parameters + ---------- + source : pyarrow.NativeFile, file-like object, or buffer-like object + + Returns + ------- + message : Message + """ + cdef: + Message result = Message() + NativeFile cpp_file + + if not isinstance(source, NativeFile): + if hasattr(source, 'read'): + source = PythonFile(source) + else: + source = BufferReader(source) + + if not isinstance(source, NativeFile): + raise ValueError('Unable to read message from object with type: {0}' + .format(type(source))) + + source._assert_readable() + + cpp_file = source + + with nogil: + check_status(ReadMessage(cpp_file.rd_file.get(), + &result.message)) + + return result + + +def read_record_batch(Message batch_message, Schema schema): + """ + Read RecordBatch from message, given a known schema + + Parameters + ---------- + batch_message : Message + Such as that obtained from read_message + schema : Schema + + Returns + ------- + batch : RecordBatch + """ + cdef shared_ptr[CRecordBatch] result + + with nogil: + check_status(ReadRecordBatch(deref(batch_message.message.get()), + schema.sp_schema, &result)) + + return pyarrow_wrap_batch(result) http://git-wip-us.apache.org/repos/asf/arrow/blob/bb0a7588/python/pyarrow/ipc.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/ipc.py b/python/pyarrow/ipc.py index 6173299..e8ea3ac 100644 --- a/python/pyarrow/ipc.py +++ b/python/pyarrow/ipc.py @@ -18,6 +18,11 @@ # Arrow file and stream reader/writer classes, and other messaging tools import pyarrow as pa + +from pyarrow.lib import (Message, MessageReader, # noqa + read_message, read_record_batch, + read_tensor, write_tensor, + get_record_batch_size, get_tensor_size) import pyarrow.lib as lib @@ -33,10 +38,6 @@ class RecordBatchStreamReader(lib._RecordBatchReader): def __init__(self, source): self._open(source) - def __iter__(self): - while True: - yield self.get_next_batch() - class RecordBatchStreamWriter(lib._RecordBatchWriter): """ @@ -136,7 +137,7 @@ def serialize_pandas(df): """ batch = pa.RecordBatch.from_pandas(df) sink = pa.InMemoryOutputStream() - writer = pa.RecordBatchFileWriter(sink, batch.schema) + writer = pa.RecordBatchStreamWriter(sink, batch.schema) writer.write_batch(batch) writer.close() return sink.get_result() @@ -157,6 +158,6 @@ def deserialize_pandas(buf, nthreads=1): df : pandas.DataFrame """ buffer_reader = pa.BufferReader(buf) - reader = pa.RecordBatchFileReader(buffer_reader) + reader = pa.RecordBatchStreamReader(buffer_reader) table = reader.read_all() return table.to_pandas(nthreads=nthreads) http://git-wip-us.apache.org/repos/asf/arrow/blob/bb0a7588/python/pyarrow/lib.pyx ---------------------------------------------------------------------- diff --git a/python/pyarrow/lib.pyx b/python/pyarrow/lib.pyx index 13c1822..cf8e4df 100644 --- a/python/pyarrow/lib.pyx +++ b/python/pyarrow/lib.pyx @@ -106,9 +106,15 @@ include "array.pxi" # Column, Table, Record Batch include "table.pxi" -# File IO, IPC +# File IO include "io.pxi" +# IPC / Messaging +include "ipc.pxi" + +# Feather format +include "feather.pxi" + #---------------------------------------------------------------------- # Public API http://git-wip-us.apache.org/repos/asf/arrow/blob/bb0a7588/python/pyarrow/pandas_compat.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/pandas_compat.py b/python/pyarrow/pandas_compat.py index a9569b2..c909b3e 100644 --- a/python/pyarrow/pandas_compat.py +++ b/python/pyarrow/pandas_compat.py @@ -22,7 +22,7 @@ import pandas as pd import six import pyarrow as pa -from pyarrow.compat import PY2 +from pyarrow.compat import PY2 # noqa INDEX_LEVEL_NAME_REGEX = re.compile(r'^__index_level_\d+__$') http://git-wip-us.apache.org/repos/asf/arrow/blob/bb0a7588/python/pyarrow/table.pxi ---------------------------------------------------------------------- diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 01e5306..575755d 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -432,9 +432,23 @@ cdef class RecordBatch: return self._schema - def __getitem__(self, i): + def column(self, i): + """ + Select single column from record batcha + + Returns + ------- + column : pyarrow.Array + """ + if not -self.num_columns <= i < self.num_columns: + raise IndexError( + 'Record batch column index {:d} is out of range'.format(i) + ) return pyarrow_wrap_array(self.batch.column(i)) + def __getitem__(self, i): + return self.column(i) + def slice(self, offset=0, length=None): """ Compute zero-copy slice of this RecordBatch http://git-wip-us.apache.org/repos/asf/arrow/blob/bb0a7588/python/pyarrow/tests/conftest.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/tests/conftest.py b/python/pyarrow/tests/conftest.py index fa9608f..da94da9 100644 --- a/python/pyarrow/tests/conftest.py +++ b/python/pyarrow/tests/conftest.py @@ -26,7 +26,7 @@ defaults = { } try: - import pyarrow.parquet + import pyarrow.parquet # noqa defaults['parquet'] = True except ImportError: pass http://git-wip-us.apache.org/repos/asf/arrow/blob/bb0a7588/python/pyarrow/tests/test_array.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index e0a7416..ed81531 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -225,7 +225,7 @@ def test_simple_type_construction(): (pa.timestamp('us', 'UTC'), 'datetimetz'), pytest.mark.xfail((pa.time32('s'), None), raises=NotImplementedError), pytest.mark.xfail((pa.time64('us'), None), raises=NotImplementedError), - ] + ] ) def test_logical_type(type, expected): assert get_logical_type(type) == expected http://git-wip-us.apache.org/repos/asf/arrow/blob/bb0a7588/python/pyarrow/tests/test_convert_builtin.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/tests/test_convert_builtin.py b/python/pyarrow/tests/test_convert_builtin.py index 62592f9..ec26159 100644 --- a/python/pyarrow/tests/test_convert_builtin.py +++ b/python/pyarrow/tests/test_convert_builtin.py @@ -22,6 +22,7 @@ import pyarrow as pa import datetime import decimal + class StrangeIterable: def __init__(self, lst): self.lst = lst @@ -29,6 +30,7 @@ class StrangeIterable: def __iter__(self): return self.lst.__iter__() + class TestConvertIterable(unittest.TestCase): def test_iterable_types(self): @@ -61,6 +63,7 @@ class TestLimitedConvertIterator(unittest.TestCase): arr2 = pa.array((0, 1, 2)) assert arr1.equals(arr2) + class TestConvertSequence(unittest.TestCase): def test_sequence_types(self): http://git-wip-us.apache.org/repos/asf/arrow/blob/bb0a7588/python/pyarrow/tests/test_feather.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/tests/test_feather.py b/python/pyarrow/tests/test_feather.py index 91bf56b..7978ace 100644 --- a/python/pyarrow/tests/test_feather.py +++ b/python/pyarrow/tests/test_feather.py @@ -359,7 +359,8 @@ class TestFeatherReader(unittest.TestCase): expected = df.rename(columns=str) self._check_pandas_roundtrip(df, expected) - @pytest.mark.skipif(not os.path.supports_unicode_filenames, reason='unicode filenames not supported') + @pytest.mark.skipif(not os.path.supports_unicode_filenames, + reason='unicode filenames not supported') def test_unicode_filename(self): # GH #209 name = (b'Besa_Kavaj\xc3\xab.feather').decode('utf-8') http://git-wip-us.apache.org/repos/asf/arrow/blob/bb0a7588/python/pyarrow/tests/test_ipc.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index b91a8e9..b2b90d4 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -38,7 +38,7 @@ class MessagingTest(object): return io.BytesIO() def _get_source(self): - return pa.BufferReader(self.sink.getvalue()) + return self.sink.getvalue() def write_batches(self): nrows = 5 @@ -79,7 +79,7 @@ class TestFile(MessagingTest, unittest.TestCase): def test_simple_roundtrip(self): batches = self.write_batches() - file_contents = self._get_source() + file_contents = pa.BufferReader(self._get_source()) reader = pa.open_file(file_contents) @@ -93,7 +93,7 @@ class TestFile(MessagingTest, unittest.TestCase): def test_read_all(self): batches = self.write_batches() - file_contents = self._get_source() + file_contents = pa.BufferReader(self._get_source()) reader = pa.open_file(file_contents) @@ -114,7 +114,7 @@ class TestStream(MessagingTest, unittest.TestCase): def test_simple_roundtrip(self): batches = self.write_batches() - file_contents = self._get_source() + file_contents = pa.BufferReader(self._get_source()) reader = pa.open_stream(file_contents) assert reader.schema.equals(batches[0].schema) @@ -131,7 +131,7 @@ class TestStream(MessagingTest, unittest.TestCase): def test_read_all(self): batches = self.write_batches() - file_contents = self._get_source() + file_contents = pa.BufferReader(self._get_source()) reader = pa.open_stream(file_contents) result = reader.read_all() @@ -139,6 +139,55 @@ class TestStream(MessagingTest, unittest.TestCase): assert result.equals(expected) +class TestMessageReader(MessagingTest, unittest.TestCase): + + def _get_example_messages(self): + batches = self.write_batches() + file_contents = self._get_source() + buf_reader = pa.BufferReader(file_contents) + reader = pa.MessageReader.open_stream(buf_reader) + return batches, list(reader) + + def _get_writer(self, sink, schema): + return pa.RecordBatchStreamWriter(sink, schema) + + def test_ctors_no_segfault(self): + with pytest.raises(TypeError): + repr(pa.Message()) + + with pytest.raises(TypeError): + repr(pa.MessageReader()) + + def test_message_reader(self): + _, messages = self._get_example_messages() + + assert len(messages) == 6 + assert messages[0].type == 'schema' + for msg in messages[1:]: + assert msg.type == 'record batch' + + def test_serialize_read_message(self): + _, messages = self._get_example_messages() + + msg = messages[0] + buf = msg.serialize() + + restored = pa.read_message(buf) + restored2 = pa.read_message(pa.BufferReader(buf)) + restored3 = pa.read_message(buf.to_pybytes()) + + assert msg.equals(restored) + assert msg.equals(restored2) + assert msg.equals(restored3) + + def test_read_record_batch(self): + batches, messages = self._get_example_messages() + + for batch, message in zip(batches, messages[1:]): + read_batch = pa.read_record_batch(message, batch.schema) + assert read_batch.equals(batch) + + class TestSocket(MessagingTest, unittest.TestCase): class StreamReaderServer(threading.Thread): http://git-wip-us.apache.org/repos/asf/arrow/blob/bb0a7588/python/pyarrow/tests/test_parquet.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/tests/test_parquet.py b/python/pyarrow/tests/test_parquet.py index 94a0e38..f606a7f 100644 --- a/python/pyarrow/tests/test_parquet.py +++ b/python/pyarrow/tests/test_parquet.py @@ -453,13 +453,15 @@ def test_date_time_types(): table = pa.Table.from_arrays([a1, a2, a3, a4, a5, a6], ['date32', 'date64', 'timestamp[us]', - 'time32[s]', 'time64[us]', 'time32_from64[s]']) + 'time32[s]', 'time64[us]', + 'time32_from64[s]']) # date64 as date32 # time32[s] to time32[ms] expected = pa.Table.from_arrays([a1, a1, a3, a4, a5, ex_a6], ['date32', 'date64', 'timestamp[us]', - 'time32[s]', 'time64[us]', 'time32_from64[s]']) + 'time32[s]', 'time64[us]', + 'time32_from64[s]']) _check_roundtrip(table, expected=expected, version='2.0') @@ -848,6 +850,7 @@ def test_read_multiple_files(tmpdir): with pytest.raises(ValueError): read_multiple_files(mixed_paths) + @parquet def test_multiindex_duplicate_values(tmpdir): num_rows = 3 http://git-wip-us.apache.org/repos/asf/arrow/blob/bb0a7588/python/pyarrow/tests/test_table.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index 3198941..c2aeda9 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -75,6 +75,10 @@ def test_recordbatch_basics(): ('c1', [-10, -5, 0, 5, 10]) ]) + with pytest.raises(IndexError): + # bounds checking + batch[2] + def test_recordbatch_slice(): data = [ http://git-wip-us.apache.org/repos/asf/arrow/blob/bb0a7588/python/pyarrow/tests/test_tensor.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/tests/test_tensor.py b/python/pyarrow/tests/test_tensor.py index a83f6f2..c495834 100644 --- a/python/pyarrow/tests/test_tensor.py +++ b/python/pyarrow/tests/test_tensor.py @@ -115,6 +115,7 @@ def test_tensor_size(): tensor = pa.Tensor.from_numpy(data) assert pa.get_tensor_size(tensor) > (data.size * 8) + def test_read_tensor(tmpdir): # Create and write tensor tensor data = np.random.randn(10, 4)