This is an automated email from the ASF dual-hosted git repository. wesm pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push: new f05369e ARROW-4796: [Flight/Python] Keep underlying Python object alive in FlightServerBase.do_get f05369e is described below commit f05369e6dc7cb5a36048cebb78d0d69b32a27b6f Author: David Li <david.m...@twosigma.com> AuthorDate: Tue Mar 12 09:15:09 2019 -0500 ARROW-4796: [Flight/Python] Keep underlying Python object alive in FlightServerBase.do_get Author: David Li <david.m...@twosigma.com> Closes #3834 from lihalite/arrow-4796 and squashes the following commits: 942b9a708 <David Li> Keep underlying Python object alive in FlightServerBase.do_get --- cpp/src/arrow/python/flight.cc | 13 +++++ cpp/src/arrow/python/flight.h | 14 +++++ python/pyarrow/_flight.pyx | 7 ++- python/pyarrow/includes/libarrow_flight.pxd | 8 +++ python/pyarrow/tests/test_flight.py | 80 +++++++++++++++++++++++++++++ 5 files changed, 121 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc index a8bae63..ec25d32 100644 --- a/cpp/src/arrow/python/flight.cc +++ b/cpp/src/arrow/python/flight.cc @@ -117,6 +117,19 @@ Status PyFlightResultStream::Next(std::unique_ptr<arrow::flight::Result>* result return CheckPyError(); } +PyFlightDataStream::PyFlightDataStream( + PyObject* data_source, std::unique_ptr<arrow::flight::FlightDataStream> stream) + : stream_(std::move(stream)) { + Py_INCREF(data_source); + data_source_.reset(data_source); +} + +std::shared_ptr<arrow::Schema> PyFlightDataStream::schema() { return stream_->schema(); } + +Status PyFlightDataStream::Next(arrow::flight::FlightPayload* payload) { + return stream_->Next(payload); +} + Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema, const arrow::flight::FlightDescriptor& descriptor, const std::vector<arrow::flight::FlightEndpoint>& endpoints, diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h index effd1a8..128784f 100644 --- a/cpp/src/arrow/python/flight.h +++ b/cpp/src/arrow/python/flight.h @@ -92,6 +92,20 @@ class ARROW_PYTHON_EXPORT PyFlightResultStream : public arrow::flight::ResultStr PyFlightResultStreamCallback callback_; }; +/// \brief A wrapper around a FlightDataStream that keeps alive a +/// Python object backing it. +class ARROW_PYTHON_EXPORT PyFlightDataStream : public arrow::flight::FlightDataStream { + public: + explicit PyFlightDataStream(PyObject* data_source, + std::unique_ptr<arrow::flight::FlightDataStream> stream); + std::shared_ptr<arrow::Schema> schema() override; + Status Next(arrow::flight::FlightPayload* payload) override; + + private: + OwnedRefNoGIL data_source_; + std::unique_ptr<arrow::flight::FlightDataStream> stream_; +}; + ARROW_PYTHON_EXPORT Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema, const arrow::flight::FlightDescriptor& descriptor, diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 532cf54..695513a 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -491,13 +491,18 @@ cdef void _do_put(void* self, cdef void _do_get(void* self, CTicket ticket, unique_ptr[CFlightDataStream]* stream) except *: """Callback for implementing Flight servers in Python.""" + cdef: + unique_ptr[CFlightDataStream] data_stream + py_ticket = Ticket(ticket.ticket) result = (<object> self).do_get(py_ticket) if not isinstance(result, FlightDataStream): raise TypeError("FlightServerBase.do_get must return " "a FlightDataStream") - stream[0] = unique_ptr[CFlightDataStream]( + data_stream = unique_ptr[CFlightDataStream]( (<FlightDataStream> result).to_stream()) + stream[0] = unique_ptr[CFlightDataStream]( + new CPyFlightDataStream(result, move(data_stream))) cdef void _do_action_result_next(void* self, diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 0271f33..153f725 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -156,6 +156,11 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil: CPyFlightResultStream(object generator, function[cb_result_next] callback) + cdef cppclass CPyFlightDataStream\ + " arrow::py::flight::PyFlightDataStream"(CFlightDataStream): + CPyFlightDataStream(object data_source, + unique_ptr[CFlightDataStream] stream) + cdef CStatus CreateFlightInfo" arrow::py::flight::CreateFlightInfo"( shared_ptr[CSchema] schema, CFlightDescriptor& descriptor, @@ -163,3 +168,6 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil: uint64_t total_records, uint64_t total_bytes, unique_ptr[CFlightInfo]* out) + +cdef extern from "<utility>" namespace "std": + unique_ptr[CFlightDataStream] move(unique_ptr[CFlightDataStream]) diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py new file mode 100644 index 0000000..d225f77 --- /dev/null +++ b/python/pyarrow/tests/test_flight.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import contextlib +import socket +import threading + +import pytest + +import pyarrow as pa + + +flight = pytest.importorskip("pyarrow.flight") + + +class ConstantFlightServer(flight.FlightServerBase): + """A Flight server that always returns the same data. + + See ARROW-4796: this server implementation will segfault if Flight + does not properly hold a reference to the Table object. + """ + + def do_get(self, ticket): + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + table = pa.Table.from_arrays(data, names=['a']) + return flight.RecordBatchStream(table) + + +@contextlib.contextmanager +def flight_server(server_base, *args, **kwargs): + """Spawn a Flight server on a free port, shutting it down when done.""" + # Find a free port + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + with contextlib.closing(sock) as sock: + sock.bind(('', 0)) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + port = sock.getsockname()[1] + + server_instance = server_base(*args, **kwargs) + + def _server_thread(): + server_instance.run(port) + + thread = threading.Thread(target=_server_thread, daemon=True) + thread.start() + + yield port + + server_instance.shutdown() + thread.join() + + +def test_flight_do_get(): + """Try a simple do_get call.""" + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + table = pa.Table.from_arrays(data, names=['a']) + + with flight_server(ConstantFlightServer) as server_port: + client = flight.FlightClient.connect('localhost', server_port) + data = client.do_get(flight.Ticket(b''), table.schema).read_all() + assert data.equals(table)