This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 2e61bcf ARROW-4587: [C++] Fix segfaults around DoPut implementation
2e61bcf is described below
commit 2e61bcf0f67bbe051a2d252c8bd4f2e8df336ea3
Author: David Li <[email protected]>
AuthorDate: Tue Feb 19 16:32:18 2019 +0100
ARROW-4587: [C++] Fix segfaults around DoPut implementation
Also fixes exposing `arrow::ipc::DictionaryMemo` as `arrow::DictionaryMemo`.
Author: David Li <[email protected]>
Closes #3660 from lihalite/fix-flight-c++ and squashes the following commits:
e670cb11c <David Li> Fix segfaults around DoPut implementation
---
cpp/src/arrow/array.h | 2 +-
cpp/src/arrow/array/builder_binary.h | 2 +-
cpp/src/arrow/flight/CMakeLists.txt | 15 ++++---
cpp/src/arrow/flight/client.cc | 40 +++++++++---------
cpp/src/arrow/flight/client.h | 5 +--
cpp/src/arrow/flight/customize_protobuf.h | 19 +++------
cpp/src/arrow/flight/flight-benchmark.cc | 2 +-
cpp/src/arrow/flight/internal.cc | 6 ++-
cpp/src/arrow/flight/internal.h | 19 ++++++---
cpp/src/arrow/flight/perf-server.cc | 9 ++--
cpp/src/arrow/flight/protocol-internal.h | 6 +--
cpp/src/arrow/flight/serialization-internal.cc | 58 +++++++++++++++++++++-----
cpp/src/arrow/flight/serialization-internal.h | 19 ++++-----
cpp/src/arrow/flight/server.cc | 36 +++++++++-------
cpp/src/arrow/flight/server.h | 25 ++---------
cpp/src/arrow/flight/types.cc | 2 -
cpp/src/arrow/flight/types.h | 8 ++++
cpp/src/arrow/ipc/writer.h | 8 ++--
18 files changed, 158 insertions(+), 123 deletions(-)
diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h
index 674bf7b..f8d451a 100644
--- a/cpp/src/arrow/array.h
+++ b/cpp/src/arrow/array.h
@@ -32,7 +32,7 @@
#include "arrow/util/bit-util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/macros.h"
-#include "arrow/util/string_view.h"
+#include "arrow/util/string_view.h" // IWYU pragma: export
#include "arrow/util/visibility.h"
namespace arrow {
diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h
index abd8387..67d579d 100644
--- a/cpp/src/arrow/array/builder_binary.h
+++ b/cpp/src/arrow/array/builder_binary.h
@@ -29,7 +29,7 @@
#include "arrow/status.h"
#include "arrow/type_traits.h"
#include "arrow/util/macros.h"
-#include "arrow/util/string_view.h"
+#include "arrow/util/string_view.h" // IWYU pragma: export
namespace arrow {
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index 46329e4..f02bc21 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -127,6 +127,9 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS)
if(ARROW_BUILD_TESTS)
add_dependencies(arrow-flight-test flight-test-server)
endif()
+
+ add_dependencies(arrow_flight flight-test-server flight-test-integration-client
+ flight-test-integration-server)
endif()
if(ARROW_BUILD_BENCHMARKS)
@@ -139,21 +142,23 @@ if(ARROW_BUILD_BENCHMARKS)
"--cpp_out=${CMAKE_CURRENT_BINARY_DIR}"
"perf.proto"
DEPENDS ${PROTO_DEPENDS})
- add_executable(flight-perf-server perf-server.cc perf.pb.cc)
- target_link_libraries(flight-perf-server
+ add_executable(arrow-flight-perf-server perf-server.cc perf.pb.cc)
+ target_link_libraries(arrow-flight-perf-server
arrow_flight_shared
arrow_flight_testing_shared
${ARROW_FLIGHT_TEST_LINK_LIBS}
${GFLAGS_LIBRARY}
${GTEST_LIBRARY})
- add_executable(flight-benchmark flight-benchmark.cc perf.pb.cc)
- target_link_libraries(flight-benchmark
+ add_executable(arrow-flight-benchmark flight-benchmark.cc perf.pb.cc)
+ target_link_libraries(arrow-flight-benchmark
arrow_flight_static
arrow_testing_static
${ARROW_FLIGHT_TEST_LINK_LIBS}
${GFLAGS_LIBRARY}
${GTEST_LIBRARY})
- add_dependencies(flight-benchmark flight-perf-server)
+ add_dependencies(arrow-flight-benchmark arrow-flight-perf-server)
+
+ add_dependencies(arrow_flight arrow-flight-benchmark)
endif(ARROW_BUILD_BENCHMARKS)
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 8520777..9925c25 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -16,14 +16,14 @@
// under the License.
#include "arrow/flight/client.h"
-#include "arrow/flight/protocol-internal.h"
+#include "arrow/flight/protocol-internal.h" // IWYU pragma: keep
#include <memory>
#include <sstream>
#include <string>
#include <utility>
-#include "grpcpp/grpcpp.h"
+#include <grpcpp/grpcpp.h>
#include "arrow/ipc/dictionary.h"
#include "arrow/ipc/metadata-internal.h"
@@ -36,12 +36,14 @@
#include "arrow/flight/internal.h"
#include "arrow/flight/serialization-internal.h"
-
-using arrow::ipc::internal::IpcPayload;
+#include "arrow/flight/types.h"
namespace pb = arrow::flight::protocol;
namespace arrow {
+
+class MemoryPool;
+
namespace flight {
struct ClientRpc {
@@ -106,8 +108,6 @@ class FlightStreamReader : public RecordBatchReader {
std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream_;
};
-class FlightClient;
-
/// \brief A RecordBatchWriter implementation that writes to a Flight
/// DoPut stream.
class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter {
@@ -119,8 +119,9 @@ class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter {
: rpc_(std::move(rpc)), descriptor_(descriptor), schema_(schema),
pool_(pool) {}
Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false)
override {
- IpcPayload payload;
- RETURN_NOT_OK(ipc::internal::GetRecordBatchPayload(batch, pool_,
&payload));
+ FlightPayload payload;
+ RETURN_NOT_OK(
+ ipc::internal::GetRecordBatchPayload(batch, pool_,
&payload.ipc_message));
if (!writer_->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
grpc::WriteOptions())) {
@@ -296,19 +297,18 @@ class FlightClient::FlightClientImpl {
stub_->DoPut(&out->rpc_->context, &out->response));
// First write the descriptor and schema to the stream.
- pb::FlightData descriptor_message;
- RETURN_NOT_OK(
- internal::ToProto(descriptor,
descriptor_message.mutable_flight_descriptor()));
-
- std::shared_ptr<Buffer> header_buf;
- RETURN_NOT_OK(Buffer::FromString("", &header_buf));
+ FlightPayload payload;
ipc::DictionaryMemo dictionary_memo;
- RETURN_NOT_OK(ipc::SerializeSchema(*schema, out->pool_, &header_buf));
- RETURN_NOT_OK(
- ipc::internal::WriteSchemaMessage(*schema, &dictionary_memo,
&header_buf));
- descriptor_message.set_data_header(header_buf->ToString());
-
- if (!write_stream->Write(descriptor_message, grpc::WriteOptions())) {
+ RETURN_NOT_OK(ipc::internal::GetSchemaPayload(*schema, out->pool_,
&dictionary_memo,
+ &payload.ipc_message));
+ pb::FlightDescriptor pb_descr;
+ RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descr));
+ std::string str_descr;
+ pb_descr.SerializeToString(&str_descr);
+ RETURN_NOT_OK(Buffer::FromString(str_descr, &payload.descriptor));
+
+ if (!write_stream->Write(*reinterpret_cast<const
pb::FlightData*>(&payload),
+ grpc::WriteOptions())) {
std::stringstream ss;
ss << "Could not write descriptor and schema to stream: "
<< rpc->context.debug_error_string();
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 334158d..e88b73f 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -28,18 +28,17 @@
#include "arrow/status.h"
#include "arrow/util/visibility.h"
-#include "arrow/flight/types.h"
+#include "arrow/flight/types.h" // IWYU pragma: keep
namespace arrow {
+class MemoryPool;
class RecordBatch;
class RecordBatchReader;
class Schema;
namespace flight {
-class FlightPutWriter;
-
/// \brief Client class for Arrow Flight RPC services (gRPC-based).
/// API experimental for now
class ARROW_EXPORT FlightClient {
diff --git a/cpp/src/arrow/flight/customize_protobuf.h b/cpp/src/arrow/flight/customize_protobuf.h
index fd2e086..1f69253 100644
--- a/cpp/src/arrow/flight/customize_protobuf.h
+++ b/cpp/src/arrow/flight/customize_protobuf.h
@@ -20,7 +20,7 @@
#include <limits>
#include <memory>
-#include "grpcpp/impl/codegen/config_protobuf.h"
+#include <grpcpp/impl/codegen/config_protobuf.h>
// It is necessary to undefined this macro so that the protobuf
// SerializationTraits specialization is not declared in proto_utils.h. We've
@@ -29,20 +29,13 @@
// for our faster serialization-deserialization path
#undef GRPC_OPEN_SOURCE_PROTO
-#include "grpcpp/impl/codegen/proto_utils.h"
+#include <grpcpp/impl/codegen/proto_utils.h>
namespace arrow {
-namespace ipc {
-namespace internal {
-
-struct IpcPayload;
-
-} // namespace internal
-} // namespace ipc
-
namespace flight {
struct FlightData;
+struct FlightPayload;
namespace protocol {
@@ -55,12 +48,12 @@ class FlightData;
namespace grpc {
using arrow::flight::FlightData;
-using arrow::ipc::internal::IpcPayload;
+using arrow::flight::FlightPayload;
class ByteBuffer;
class Status;
-Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool*
own_buffer);
+Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, bool*
own_buffer);
Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out);
// This class provides a protobuf serializer. It translates between protobuf
@@ -88,7 +81,7 @@ class SerializationTraits<T, typename std::enable_if<std::is_same<
public:
static Status Serialize(const grpc::protobuf::Message& msg, ByteBuffer* bb,
bool* own_buffer) {
- return FlightDataSerialize(*reinterpret_cast<const IpcPayload*>(&msg), bb,
+ return FlightDataSerialize(*reinterpret_cast<const FlightPayload*>(&msg),
bb,
own_buffer);
}
diff --git a/cpp/src/arrow/flight/flight-benchmark.cc b/cpp/src/arrow/flight/flight-benchmark.cc
index adb775c..0f3eb6b 100644
--- a/cpp/src/arrow/flight/flight-benchmark.cc
+++ b/cpp/src/arrow/flight/flight-benchmark.cc
@@ -183,7 +183,7 @@ int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
const int port = 31337;
- arrow::flight::TestServer server("flight-perf-server", port);
+ arrow::flight::TestServer server("arrow-flight-perf-server", port);
server.Start();
arrow::Status s = arrow::flight::RunPerformanceTest(port);
diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc
index 21268e8..1085e11 100644
--- a/cpp/src/arrow/flight/internal.cc
+++ b/cpp/src/arrow/flight/internal.cc
@@ -16,18 +16,20 @@
// under the License.
#include "arrow/flight/internal.h"
+#include "arrow/flight/protocol-internal.h"
-#include "arrow/flight/customize_protobuf.h"
-
+#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <grpcpp/grpcpp.h>
+#include "arrow/buffer.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
+#include "arrow/memory_pool.h"
#include "arrow/status.h"
#include "arrow/util/logging.h"
diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h
index 15c3d71..05adb6c 100644
--- a/cpp/src/arrow/flight/internal.h
+++ b/cpp/src/arrow/flight/internal.h
@@ -20,14 +20,15 @@
#include <memory>
#include <string>
-#include <grpcpp/grpcpp.h>
-
-#include "arrow/buffer.h"
-#include "arrow/ipc/writer.h"
+#include "arrow/flight/protocol-internal.h" // IWYU pragma: keep
+#include "arrow/flight/types.h"
#include "arrow/util/macros.h"
-#include "arrow/flight/protocol-internal.h"
-#include "arrow/flight/types.h"
+namespace grpc {
+
+class Status;
+
+} // namespace grpc
namespace arrow {
@@ -36,6 +37,12 @@ class Status;
namespace pb = arrow::flight::protocol;
+namespace ipc {
+
+class Message;
+
+} // namespace ipc
+
namespace flight {
#define GRPC_RETURN_NOT_OK(s) \
diff --git a/cpp/src/arrow/flight/perf-server.cc b/cpp/src/arrow/flight/perf-server.cc
index b470283..bee7060 100644
--- a/cpp/src/arrow/flight/perf-server.cc
+++ b/cpp/src/arrow/flight/perf-server.cc
@@ -41,8 +41,6 @@ DEFINE_int32(port, 31337, "Server port to listen on");
namespace perf = arrow::flight::perf;
namespace proto = arrow::flight::protocol;
-using IpcPayload = arrow::ipc::internal::IpcPayload;
-
namespace arrow {
namespace flight {
@@ -73,10 +71,10 @@ class PerfDataStream : public FlightDataStream {
std::shared_ptr<Schema> schema() override { return schema_; }
- Status Next(IpcPayload* payload) override {
+ Status Next(FlightPayload* payload) override {
if (records_sent_ >= total_records_) {
// Signal that iteration is over
- payload->metadata = nullptr;
+ payload->ipc_message.metadata = nullptr;
return Status::OK();
}
@@ -98,7 +96,8 @@ class PerfDataStream : public FlightDataStream {
} else {
records_sent_ += batch_length_;
}
- return ipc::internal::GetRecordBatchPayload(*batch, default_memory_pool(),
payload);
+ return ipc::internal::GetRecordBatchPayload(*batch, default_memory_pool(),
+ &payload->ipc_message);
}
private:
diff --git a/cpp/src/arrow/flight/protocol-internal.h b/cpp/src/arrow/flight/protocol-internal.h
index d3ba77f..2e8dd32 100644
--- a/cpp/src/arrow/flight/protocol-internal.h
+++ b/cpp/src/arrow/flight/protocol-internal.h
@@ -17,7 +17,7 @@
#pragma once
// Need to include this first to get our gRPC customizations
-#include "arrow/flight/customize_protobuf.h"
+#include "arrow/flight/customize_protobuf.h" // IWYU pragma: export
-#include "arrow/flight/Flight.grpc.pb.h"
-#include "arrow/flight/Flight.pb.h"
+#include "arrow/flight/Flight.grpc.pb.h" // IWYU pragma: export
+#include "arrow/flight/Flight.pb.h" // IWYU pragma: export
diff --git a/cpp/src/arrow/flight/serialization-internal.cc b/cpp/src/arrow/flight/serialization-internal.cc
index 67b2155..0c031e0 100644
--- a/cpp/src/arrow/flight/serialization-internal.cc
+++ b/cpp/src/arrow/flight/serialization-internal.cc
@@ -17,6 +17,12 @@
#include "arrow/flight/serialization-internal.h"
+#include <string>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/server.h"
+#include "arrow/ipc/writer.h"
+
namespace arrow {
namespace flight {
namespace internal {
@@ -105,17 +111,27 @@ using google::protobuf::io::ArrayOutputStream;
using google::protobuf::io::CodedInputStream;
using google::protobuf::io::CodedOutputStream;
-Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool*
own_buffer) {
+Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, bool*
own_buffer) {
size_t total_size = 0;
- DCHECK_LT(msg.metadata->size(), kInt32Max);
- const int32_t metadata_size = static_cast<int32_t>(msg.metadata->size());
+ // Write the descriptor if present
+ int32_t descriptor_size = 0;
+ if (msg.descriptor != nullptr) {
+ DCHECK_LT(msg.descriptor->size(), kInt32Max);
+ descriptor_size = static_cast<int32_t>(msg.descriptor->size());
+ total_size += 1 + WireFormatLite::LengthDelimitedSize(descriptor_size);
+ }
+
+ const arrow::ipc::internal::IpcPayload& ipc_msg = msg.ipc_message;
+
+ DCHECK_LT(ipc_msg.metadata->size(), kInt32Max);
+ const int32_t metadata_size = static_cast<int32_t>(ipc_msg.metadata->size());
// 1 byte for metadata tag
total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size);
int64_t body_size = 0;
- for (const auto& buffer : msg.body_buffers) {
+ for (const auto& buffer : ipc_msg.body_buffers) {
// Buffer may be null when the row length is zero, or when all
// entries are invalid.
if (!buffer) continue;
@@ -130,7 +146,7 @@ Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool* own_buf
// 2 bytes for body tag
// Only written when there are body buffers
- if (msg.body_length > 0) {
+ if (ipc_msg.body_length > 0) {
total_size += 2 +
WireFormatLite::LengthDelimitedSize(static_cast<size_t>(body_size));
}
@@ -150,15 +166,24 @@ Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool* own_buf
static_cast<int>(slice.size()));
CodedOutputStream pb_stream(&writer);
+ // Write descriptor
+ if (msg.descriptor != nullptr) {
+ WireFormatLite::WriteTag(pb::FlightData::kFlightDescriptorFieldNumber,
+ WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
&pb_stream);
+ pb_stream.WriteVarint32(descriptor_size);
+ pb_stream.WriteRawMaybeAliased(msg.descriptor->data(),
+ static_cast<int>(msg.descriptor->size()));
+ }
+
// Write header
WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
&pb_stream);
pb_stream.WriteVarint32(metadata_size);
- pb_stream.WriteRawMaybeAliased(msg.metadata->data(),
- static_cast<int>(msg.metadata->size()));
+ pb_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(),
+ static_cast<int>(ipc_msg.metadata->size()));
// Don't write tag if there are no body buffers
- if (msg.body_length > 0) {
+ if (ipc_msg.body_length > 0) {
// Write body
WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
&pb_stream);
@@ -166,7 +191,7 @@ Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool* own_buf
constexpr uint8_t kPaddingBytes[8] = {0};
- for (const auto& buffer : msg.body_buffers) {
+ for (const auto& buffer : ipc_msg.body_buffers) {
// Buffer may be null when the row length is zero, or when all
// entries are invalid.
if (!buffer) continue;
@@ -213,9 +238,22 @@ Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) {
switch (field_number) {
case pb::FlightData::kFlightDescriptorFieldNumber: {
pb::FlightDescriptor pb_descriptor;
- if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) {
+ uint32_t length;
+ if (!pb_stream.ReadVarint32(&length)) {
+ return Status(StatusCode::INTERNAL,
+ "Unable to parse length of FlightDescriptor");
+ }
+ // Can't use ParseFromCodedStream as this reads the entire
+ // rest of the stream into the descriptor command field.
+ std::string buffer;
+ pb_stream.ReadString(&buffer, length);
+ if (!pb_descriptor.ParseFromString(buffer)) {
return Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor");
}
+ arrow::flight::FlightDescriptor descriptor;
+ GRPC_RETURN_NOT_OK(
+ arrow::flight::internal::FromProto(pb_descriptor, &descriptor));
+ out->descriptor.reset(new arrow::flight::FlightDescriptor(descriptor));
} break;
case pb::FlightData::kDataHeaderFieldNumber: {
if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) {
diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h
index 19c8592..d8e7aad 100644
--- a/cpp/src/arrow/flight/serialization-internal.h
+++ b/cpp/src/arrow/flight/serialization-internal.h
@@ -21,18 +21,19 @@
#pragma once
// Enable gRPC customizations
-#include "arrow/flight/protocol-internal.h"
+#include "arrow/flight/protocol-internal.h" // IWYU pragma: keep
+#include <cstdint>
#include <limits>
#include <memory>
-#include "google/protobuf/io/coded_stream.h"
-#include "google/protobuf/io/zero_copy_stream.h"
-#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
-#include "google/protobuf/wire_format_lite.h"
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
+#include <google/protobuf/wire_format_lite.h>
+#include <grpcpp/grpcpp.h>
+#include <grpcpp/impl/codegen/proto_utils.h>
#include "grpc/byte_buffer_reader.h"
-#include "grpcpp/grpcpp.h"
-#include "grpcpp/impl/codegen/proto_utils.h"
#include "arrow/ipc/writer.h"
#include "arrow/record_batch.h"
@@ -44,8 +45,6 @@
namespace pb = arrow::flight::protocol;
-using arrow::ipc::internal::IpcPayload;
-
constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();
namespace arrow {
@@ -103,7 +102,7 @@ inline arrow::Status FailSerialization(arrow::Status status) {
}
// Write FlightData to a grpc::ByteBuffer without extra copying
-Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool*
own_buffer);
+Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, bool*
own_buffer);
// Read internal::FlightData from grpc::ByteBuffer containing FlightData
// protobuf without copying
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index 2fef93d..0b95e53 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -21,11 +21,14 @@
#include <cstdint>
#include <memory>
#include <string>
+#include <utility>
-#include "grpcpp/grpcpp.h"
+#include <grpcpp/grpcpp.h>
+#include "arrow/ipc/dictionary.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
+#include "arrow/memory_pool.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"
#include "arrow/util/logging.h"
@@ -37,8 +40,6 @@
using FlightService = arrow::flight::protocol::FlightService;
using ServerContext = grpc::ServerContext;
-using arrow::ipc::internal::IpcPayload;
-
template <typename T>
using ServerWriter = grpc::ServerWriter<T>;
@@ -180,20 +181,21 @@ class FlightServiceImpl : public FlightService::Service {
GRPC_RETURN_NOT_OK(server_->DoGet(ticket, &data_stream));
// Write the schema as the first message in the stream
- IpcPayload schema_payload;
+ FlightPayload schema_payload;
MemoryPool* pool = default_memory_pool();
ipc::DictionaryMemo dictionary_memo;
GRPC_RETURN_NOT_OK(ipc::internal::GetSchemaPayload(
- *data_stream->schema(), pool, &dictionary_memo, &schema_payload));
+ *data_stream->schema(), pool, &dictionary_memo,
&schema_payload.ipc_message));
- // Pretend to be pb::FlightData, we cast back to IpcPayload in
SerializationTraits
+ // Pretend to be pb::FlightData, we cast back to FlightPayload in
+ // SerializationTraits
writer->Write(*reinterpret_cast<const pb::FlightData*>(&schema_payload),
grpc::WriteOptions());
while (true) {
- IpcPayload payload;
+ FlightPayload payload;
GRPC_RETURN_NOT_OK(data_stream->Next(&payload));
- if (payload.metadata == nullptr ||
+ if (payload.ipc_message.metadata == nullptr ||
!writer->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
grpc::WriteOptions())) {
// No more messages to write, or connection terminated for some other
@@ -207,22 +209,24 @@ class FlightServiceImpl : public FlightService::Service {
grpc::Status DoPut(ServerContext* context,
grpc::ServerReader<pb::FlightData>* reader,
pb::PutResult* response) {
// Get metadata
- pb::FlightData data;
- if (reader->Read(&data)) {
- FlightDescriptor descriptor;
+ FlightData data;
+ if (reader->Read(reinterpret_cast<pb::FlightData*>(&data))) {
// Message only lives as long as data
std::unique_ptr<ipc::Message> message;
- GRPC_RETURN_NOT_OK(internal::FromProto(data, &descriptor, &message));
+ GRPC_RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body,
&message));
if (!message || message->type() != ipc::Message::Type::SCHEMA) {
return internal::ToGrpcStatus(
Status(StatusCode::Invalid, "DoPut must start with schema/descriptor"));
+ } else if (data.descriptor == nullptr) {
+ return internal::ToGrpcStatus(
+ Status(StatusCode::Invalid, "DoPut must start with non-null descriptor"));
} else {
std::shared_ptr<Schema> schema;
GRPC_RETURN_NOT_OK(ipc::ReadSchema(*message, &schema));
auto message_reader = std::unique_ptr<FlightMessageReader>(
- new FlightMessageReaderImpl(descriptor, schema, reader));
+ new FlightMessageReaderImpl(*data.descriptor.get(), schema,
reader));
return
internal::ToGrpcStatus(server_->DoPut(std::move(message_reader)));
}
} else {
@@ -333,16 +337,16 @@ RecordBatchStream::RecordBatchStream(const std::shared_ptr<RecordBatchReader>& r
std::shared_ptr<Schema> RecordBatchStream::schema() { return
reader_->schema(); }
-Status RecordBatchStream::Next(IpcPayload* payload) {
+Status RecordBatchStream::Next(FlightPayload* payload) {
std::shared_ptr<RecordBatch> batch;
RETURN_NOT_OK(reader_->ReadNext(&batch));
if (!batch) {
// Signal that iteration is over
- payload->metadata = nullptr;
+ payload->ipc_message.metadata = nullptr;
return Status::OK();
} else {
- return ipc::internal::GetRecordBatchPayload(*batch, pool_, payload);
+ return ipc::internal::GetRecordBatchPayload(*batch, pool_,
&payload->ipc_message);
}
}
diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h
index b2e8b02..407c4bc 100644
--- a/cpp/src/arrow/flight/server.h
+++ b/cpp/src/arrow/flight/server.h
@@ -21,36 +21,19 @@
#pragma once
#include <memory>
-#include <string>
-#include <utility>
#include <vector>
#include "arrow/util/visibility.h"
-#include "arrow/flight/types.h"
-#include "arrow/ipc/dictionary.h"
+#include "arrow/flight/types.h" // IWYU pragma: keep
#include "arrow/record_batch.h"
namespace arrow {
class MemoryPool;
-class RecordBatchReader;
+class Schema;
class Status;
-namespace ipc {
-namespace internal {
-
-struct IpcPayload;
-
-} // namespace internal
-} // namespace ipc
-
-namespace io {
-
-class OutputStream;
-
-} // namespace io
-
namespace flight {
/// \brief Interface that produces a sequence of IPC payloads to be sent in
@@ -64,7 +47,7 @@ class ARROW_EXPORT FlightDataStream {
// When the stream is completed, the last payload written will have null
// metadata
- virtual Status Next(ipc::internal::IpcPayload* payload) = 0;
+ virtual Status Next(FlightPayload* payload) = 0;
};
/// \brief A basic implementation of FlightDataStream that will provide
@@ -75,7 +58,7 @@ class ARROW_EXPORT RecordBatchStream : public FlightDataStream {
explicit RecordBatchStream(const std::shared_ptr<RecordBatchReader>& reader);
std::shared_ptr<Schema> schema() override;
- Status Next(ipc::internal::IpcPayload* payload) override;
+ Status Next(FlightPayload* payload) override;
private:
MemoryPool* pool_;
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index aba93ad..fb8f8c6 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -18,8 +18,6 @@
#include "arrow/flight/types.h"
#include <memory>
-#include <sstream>
-#include <string>
#include <utility>
#include "arrow/io/memory.h"
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index 6db2655..ba0ab85 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -19,12 +19,14 @@
#pragma once
+#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
+#include "arrow/ipc/writer.h"
#include "arrow/util/visibility.h"
namespace arrow {
@@ -111,6 +113,12 @@ struct FlightEndpoint {
std::vector<Location> locations;
};
+/// \brief Staging data structure for messages about to be put on the wire
+struct FlightPayload {
+ std::shared_ptr<Buffer> descriptor;
+ ipc::internal::IpcPayload ipc_message;
+};
+
/// \brief The access coordinates for retireval of a dataset, returned by
/// GetFlightInfo
class FlightInfo {
diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h
index 5b099d5..50872e9 100644
--- a/cpp/src/arrow/ipc/writer.h
+++ b/cpp/src/arrow/ipc/writer.h
@@ -30,7 +30,6 @@
namespace arrow {
class Buffer;
-class DictionaryMemo;
class MemoryPool;
class RecordBatch;
class Schema;
@@ -47,6 +46,8 @@ class OutputStream;
namespace ipc {
+class DictionaryMemo;
+
/// \class RecordBatchWriter
/// \brief Abstract interface for writing a stream of record batches
class ARROW_EXPORT RecordBatchWriter {
@@ -298,9 +299,8 @@ namespace internal {
// These internal APIs may change without warning or deprecation
-// Intermediate data structure with metadata header plus zero or more buffers
-// for the message body. This data can either be written out directly as an
-// encapsulated IPC message or used with Flight RPCs
+// Intermediate data structure with metadata header, and zero or more buffers
+// for the message body.
struct IpcPayload {
Message::Type type;
std::shared_ptr<Buffer> metadata;