This is an automated email from the ASF dual-hosted git repository. apitrou pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push: new 6c4118b ARROW-4562: [C++] Avoid copies when serializing Flight data 6c4118b is described below commit 6c4118b274cadf044b2d0581401a018f0a438205 Author: Antoine Pitrou <anto...@python.org> AuthorDate: Wed Feb 20 10:54:12 2019 +0100 ARROW-4562: [C++] Avoid copies when serializing Flight data Also massage the Flight headers to avoid unnecessary includes and declarations. Author: Antoine Pitrou <anto...@python.org> Closes #3705 from pitrou/ARROW-4562-no-copy-grpc-serialization and squashes the following commits: b24194baa <Antoine Pitrou> ARROW-4562: Avoid copies when serializing Flight data --- cpp/src/arrow/flight/client.cc | 2 +- cpp/src/arrow/flight/customize_protobuf.h | 39 ++++-- cpp/src/arrow/flight/flight-test.cc | 6 - cpp/src/arrow/flight/serialization-internal.cc | 169 ++++++++++++++----------- cpp/src/arrow/flight/serialization-internal.h | 62 +-------- cpp/src/arrow/flight/server.cc | 4 +- 6 files changed, 129 insertions(+), 153 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 9925c25..2563126 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -68,7 +68,7 @@ class FlightStreamReader : public RecordBatchReader { std::shared_ptr<Schema> schema() const override { return schema_; } Status ReadNext(std::shared_ptr<RecordBatch>* out) override { - FlightData data; + internal::FlightData data; if (stream_finished_) { *out = nullptr; diff --git a/cpp/src/arrow/flight/customize_protobuf.h b/cpp/src/arrow/flight/customize_protobuf.h index 1f69253..9f5dfab 100644 --- a/cpp/src/arrow/flight/customize_protobuf.h +++ b/cpp/src/arrow/flight/customize_protobuf.h @@ -31,12 +31,33 @@ #include <grpcpp/impl/codegen/proto_utils.h> +namespace grpc { + +class ByteBuffer; + +} // namespace grpc + namespace arrow { namespace flight { -struct FlightData; struct FlightPayload; +namespace internal { + +struct FlightData; + +// Those two functions are defined in serialization-internal.cc + +// Write FlightData to a grpc::ByteBuffer without extra copying +grpc::Status FlightDataSerialize(const FlightPayload& msg, grpc::ByteBuffer* out, + bool* own_buffer); + +// Read internal::FlightData from grpc::ByteBuffer containing FlightData +// protobuf without copying +grpc::Status FlightDataDeserialize(grpc::ByteBuffer* buffer, FlightData* out); + +} // namespace internal + namespace protocol { class FlightData; @@ -47,15 +68,6 @@ class FlightData; namespace grpc { -using arrow::flight::FlightData; -using arrow::flight::FlightPayload; - -class ByteBuffer; -class Status; - -Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, bool* own_buffer); -Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out); - // This class provides a protobuf serializer. It translates between protobuf // objects and grpc_byte_buffers. More information about SerializationTraits can // be found in include/grpcpp/impl/codegen/serialization_traits.h. @@ -81,12 +93,13 @@ class SerializationTraits<T, typename std::enable_if<std::is_same< public: static Status Serialize(const grpc::protobuf::Message& msg, ByteBuffer* bb, bool* own_buffer) { - return FlightDataSerialize(*reinterpret_cast<const FlightPayload*>(&msg), bb, - own_buffer); + return arrow::flight::internal::FlightDataSerialize( + *reinterpret_cast<const arrow::flight::FlightPayload*>(&msg), bb, own_buffer); } static Status Deserialize(ByteBuffer* buffer, grpc::protobuf::Message* msg) { - return FlightDataDeserialize(buffer, reinterpret_cast<FlightData*>(msg)); + return arrow::flight::internal::FlightDataDeserialize( + buffer, reinterpret_cast<arrow::flight::internal::FlightData*>(msg)); } }; diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc index 9268aec..d1ab0aa 100644 --- a/cpp/src/arrow/flight/flight-test.cc +++ b/cpp/src/arrow/flight/flight-test.cc @@ -15,12 +15,6 @@ // specific language governing permissions and limitations // under the License. -#ifndef _WIN32 -#include <sys/stat.h> -#include <sys/wait.h> -#include <unistd.h> -#endif - #include <chrono> #include <cstdint> #include <cstdio> diff --git a/cpp/src/arrow/flight/serialization-internal.cc b/cpp/src/arrow/flight/serialization-internal.cc index 0c031e0..d80c0c7 100644 --- a/cpp/src/arrow/flight/serialization-internal.cc +++ b/cpp/src/arrow/flight/serialization-internal.cc @@ -17,33 +17,53 @@ #include "arrow/flight/serialization-internal.h" +#include <cstdint> +#include <limits> #include <string> +#include <vector> + +#include <google/protobuf/io/zero_copy_stream_impl_lite.h> +#include <google/protobuf/wire_format_lite.h> +#include <grpc/byte_buffer_reader.h> +#include <grpcpp/grpcpp.h> +#include <grpcpp/impl/codegen/proto_utils.h> #include "arrow/buffer.h" #include "arrow/flight/server.h" #include "arrow/ipc/writer.h" +#include "arrow/util/bit-util.h" +#include "arrow/util/logging.h" + +namespace pb = arrow::flight::protocol; + +static constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max(); namespace arrow { namespace flight { namespace internal { -bool ReadBytesZeroCopy(const std::shared_ptr<arrow::Buffer>& source_data, - CodedInputStream* input, std::shared_ptr<arrow::Buffer>* out) { +using arrow::ipc::internal::IpcPayload; + +using google::protobuf::internal::WireFormatLite; +using google::protobuf::io::ArrayOutputStream; +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::CodedOutputStream; + +using grpc::ByteBuffer; + +bool ReadBytesZeroCopy(const std::shared_ptr<Buffer>& source_data, + CodedInputStream* input, std::shared_ptr<Buffer>* out) { uint32_t length; if (!input->ReadVarint32(&length)) { return false; } - *out = arrow::SliceBuffer(source_data, input->CurrentPosition(), - static_cast<int64_t>(length)); + *out = SliceBuffer(source_data, input->CurrentPosition(), static_cast<int64_t>(length)); return input->Skip(static_cast<int>(length)); } -using google::protobuf::io::CodedInputStream; -using google::protobuf::io::CodedOutputStream; - // Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow // consumers with zero-copy -class GrpcBuffer : public arrow::MutableBuffer { +class GrpcBuffer : public MutableBuffer { public: GrpcBuffer(grpc_slice slice, bool incref) : MutableBuffer(GRPC_SLICE_START_PTR(slice), @@ -55,8 +75,7 @@ class GrpcBuffer : public arrow::MutableBuffer { grpc_slice_unref(slice_); } - static arrow::Status Wrap(grpc::ByteBuffer* cpp_buf, - std::shared_ptr<arrow::Buffer>* out) { + static Status Wrap(ByteBuffer* cpp_buf, std::shared_ptr<Buffer>* out) { // These types are guaranteed by static assertions in gRPC to have the same // in-memory representation @@ -80,7 +99,7 @@ class GrpcBuffer : public arrow::MutableBuffer { // us back a new slice with the refcount already incremented. grpc_byte_buffer_reader reader; if (!grpc_byte_buffer_reader_init(&reader, buffer)) { - return arrow::Status::IOError("Internal gRPC error reading from ByteBuffer"); + return Status::IOError("Internal gRPC error reading from ByteBuffer"); } grpc_slice slice = grpc_byte_buffer_reader_readall(&reader); grpc_byte_buffer_reader_destroy(&reader); @@ -89,37 +108,42 @@ class GrpcBuffer : public arrow::MutableBuffer { *out = std::make_shared<GrpcBuffer>(slice, false); } - return arrow::Status::OK(); + return Status::OK(); } private: grpc_slice slice_; }; -} // namespace internal -} // namespace flight -} // namespace arrow - -namespace grpc { +// Destructor callback for grpc::Slice +static void ReleaseBuffer(void* buf_ptr) { + delete reinterpret_cast<std::shared_ptr<Buffer>*>(buf_ptr); +} -using arrow::flight::FlightData; -using arrow::flight::internal::GrpcBuffer; -using arrow::flight::internal::ReadBytesZeroCopy; +// Initialize gRPC Slice from arrow Buffer +grpc::Slice SliceFromBuffer(const std::shared_ptr<Buffer>& buf) { + // Allocate persistent shared_ptr to control Buffer lifetime + auto ptr = new std::shared_ptr<Buffer>(buf); + grpc::Slice slice(const_cast<uint8_t*>(buf->data()), static_cast<size_t>(buf->size()), + &ReleaseBuffer, ptr); + // Make sure no copy was done (some grpc::Slice() constructors do an implicit memcpy) + DCHECK_EQ(slice.begin(), buf->data()); + return slice; +} -using google::protobuf::internal::WireFormatLite; -using google::protobuf::io::ArrayOutputStream; -using google::protobuf::io::CodedInputStream; -using google::protobuf::io::CodedOutputStream; +static const uint8_t kPaddingBytes[8] = {0, 0, 0, 0, 0, 0, 0, 0}; -Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, bool* own_buffer) { - size_t total_size = 0; +grpc::Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, + bool* own_buffer) { + size_t body_size = 0; + size_t header_size = 0; // Write the descriptor if present int32_t descriptor_size = 0; if (msg.descriptor != nullptr) { DCHECK_LT(msg.descriptor->size(), kInt32Max); descriptor_size = static_cast<int32_t>(msg.descriptor->size()); - total_size += 1 + WireFormatLite::LengthDelimitedSize(descriptor_size); + header_size += 1 + WireFormatLite::LengthDelimitedSize(descriptor_size); } const arrow::ipc::internal::IpcPayload& ipc_msg = msg.ipc_message; @@ -128,98 +152,94 @@ Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, bool* own_ const int32_t metadata_size = static_cast<int32_t>(ipc_msg.metadata->size()); // 1 byte for metadata tag - total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size); + header_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size); - int64_t body_size = 0; for (const auto& buffer : ipc_msg.body_buffers) { // Buffer may be null when the row length is zero, or when all // entries are invalid. if (!buffer) continue; - body_size += buffer->size(); - - const int64_t remainder = buffer->size() % 8; - if (remainder) { - body_size += 8 - remainder; - } + body_size += static_cast<size_t>(BitUtil::RoundUpToMultipleOf8(buffer->size())); } // 2 bytes for body tag // Only written when there are body buffers if (ipc_msg.body_length > 0) { - total_size += 2 + WireFormatLite::LengthDelimitedSize(static_cast<size_t>(body_size)); + // We write the body tag in the header but not the actual body data + header_size += 2 + WireFormatLite::LengthDelimitedSize(body_size) - body_size; } // TODO(wesm): messages over 2GB unlikely to be yet supported - if (total_size > kInt32Max) { + if (body_size > kInt32Max) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Cannot send record batches exceeding 2GB yet"); } - // Allocate slice, assign to output buffer - grpc::Slice slice(total_size); + // Allocate and initialize slices + std::vector<grpc::Slice> slices; + grpc::Slice header_slice(header_size); + slices.push_back(header_slice); // XXX(wesm): for debugging // std::cout << "Writing record batch with total size " << total_size << std::endl; - ArrayOutputStream writer(const_cast<uint8_t*>(slice.begin()), - static_cast<int>(slice.size())); - CodedOutputStream pb_stream(&writer); + ArrayOutputStream header_writer(const_cast<uint8_t*>(header_slice.begin()), + static_cast<int>(header_slice.size())); + CodedOutputStream header_stream(&header_writer); // Write descriptor if (msg.descriptor != nullptr) { WireFormatLite::WriteTag(pb::FlightData::kFlightDescriptorFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); - pb_stream.WriteVarint32(descriptor_size); - pb_stream.WriteRawMaybeAliased(msg.descriptor->data(), - static_cast<int>(msg.descriptor->size())); + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); + header_stream.WriteVarint32(descriptor_size); + header_stream.WriteRawMaybeAliased(msg.descriptor->data(), + static_cast<int>(msg.descriptor->size())); } // Write header WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); - pb_stream.WriteVarint32(metadata_size); - pb_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(), - static_cast<int>(ipc_msg.metadata->size())); + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); + header_stream.WriteVarint32(metadata_size); + header_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(), + static_cast<int>(ipc_msg.metadata->size())); - // Don't write tag if there are no body buffers + // Don't write body tag if there are no body buffers if (ipc_msg.body_length > 0) { - // Write body + // Write body tag WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &pb_stream); - pb_stream.WriteVarint32(static_cast<uint32_t>(body_size)); - - constexpr uint8_t kPaddingBytes[8] = {0}; + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); + header_stream.WriteVarint32(static_cast<uint32_t>(body_size)); + // Enqueue body buffers for writing, without copying for (const auto& buffer : ipc_msg.body_buffers) { // Buffer may be null when the row length is zero, or when all // entries are invalid. if (!buffer) continue; - pb_stream.WriteRawMaybeAliased(buffer->data(), static_cast<int>(buffer->size())); + slices.push_back(SliceFromBuffer(buffer)); // Write padding if not multiple of 8 - const int remainder = static_cast<int>(buffer->size() % 8); + const auto remainder = static_cast<int>( + BitUtil::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); if (remainder) { - pb_stream.WriteRawMaybeAliased(kPaddingBytes, 8 - remainder); + slices.push_back(grpc::Slice(kPaddingBytes, remainder)); } } } - DCHECK_EQ(static_cast<int>(total_size), pb_stream.ByteCount()); + DCHECK_EQ(static_cast<int>(header_size), header_stream.ByteCount()); - // Hand off the slice to the returned ByteBuffer - grpc::ByteBuffer tmp(&slice, 1); - out->Swap(&tmp); + // Hand off the slices to the returned ByteBuffer + *out = grpc::ByteBuffer(slices.data(), slices.size()); *own_buffer = true; return grpc::Status::OK; } // Read internal::FlightData from grpc::ByteBuffer containing FlightData // protobuf without copying -Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) { +grpc::Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) { if (!buffer) { - return Status(StatusCode::INTERNAL, "No payload"); + return grpc::Status(grpc::StatusCode::INTERNAL, "No payload"); } std::shared_ptr<arrow::Buffer> wrapped_buffer; @@ -240,15 +260,16 @@ Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) { pb::FlightDescriptor pb_descriptor; uint32_t length; if (!pb_stream.ReadVarint32(&length)) { - return Status(StatusCode::INTERNAL, - "Unable to parse length of FlightDescriptor"); + return grpc::Status(grpc::StatusCode::INTERNAL, + "Unable to parse length of FlightDescriptor"); } // Can't use ParseFromCodedStream as this reads the entire // rest of the stream into the descriptor command field. std::string buffer; pb_stream.ReadString(&buffer, length); if (!pb_descriptor.ParseFromString(buffer)) { - return Status(StatusCode::INTERNAL, "Unable to parse FlightDescriptor"); + return grpc::Status(grpc::StatusCode::INTERNAL, + "Unable to parse FlightDescriptor"); } arrow::flight::FlightDescriptor descriptor; GRPC_RETURN_NOT_OK( @@ -257,12 +278,14 @@ Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) { } break; case pb::FlightData::kDataHeaderFieldNumber: { if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { - return Status(StatusCode::INTERNAL, "Unable to read FlightData metadata"); + return grpc::Status(grpc::StatusCode::INTERNAL, + "Unable to read FlightData metadata"); } } break; case pb::FlightData::kDataBodyFieldNumber: { if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { - return Status(StatusCode::INTERNAL, "Unable to read FlightData body"); + return grpc::Status(grpc::StatusCode::INTERNAL, + "Unable to read FlightData body"); } } break; default: @@ -274,7 +297,9 @@ Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) { // TODO(wesm): Where and when should we verify that the FlightData is not // malformed or missing components? - return Status::OK; + return grpc::Status::OK; } -} // namespace grpc +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h index d8e7aad..4576290 100644 --- a/cpp/src/arrow/flight/serialization-internal.h +++ b/cpp/src/arrow/flight/serialization-internal.h @@ -23,32 +23,19 @@ // Enable gRPC customizations #include "arrow/flight/protocol-internal.h" // IWYU pragma: keep -#include <cstdint> -#include <limits> #include <memory> #include <google/protobuf/io/coded_stream.h> -#include <google/protobuf/io/zero_copy_stream.h> -#include <google/protobuf/io/zero_copy_stream_impl_lite.h> -#include <google/protobuf/wire_format_lite.h> -#include <grpcpp/grpcpp.h> -#include <grpcpp/impl/codegen/proto_utils.h> -#include "grpc/byte_buffer_reader.h" - -#include "arrow/ipc/writer.h" -#include "arrow/record_batch.h" -#include "arrow/status.h" -#include "arrow/util/logging.h" #include "arrow/flight/internal.h" #include "arrow/flight/types.h" -namespace pb = arrow::flight::protocol; +namespace arrow { -constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max(); +class Buffer; -namespace arrow { namespace flight { +namespace internal { /// Internal, not user-visible type used for memory-efficient reads from gRPC /// stream @@ -63,49 +50,6 @@ struct FlightData { std::shared_ptr<Buffer> body; }; -namespace internal { - -using google::protobuf::io::CodedInputStream; -using google::protobuf::io::CodedOutputStream; - -bool ReadBytesZeroCopy(const std::shared_ptr<arrow::Buffer>& source_data, - CodedInputStream* input, std::shared_ptr<arrow::Buffer>* out); - } // namespace internal } // namespace flight } // namespace arrow - -namespace grpc { - -using arrow::flight::FlightData; - -using google::protobuf::internal::WireFormatLite; -using google::protobuf::io::ArrayOutputStream; -using google::protobuf::io::CodedInputStream; -using google::protobuf::io::CodedOutputStream; - -// Helper to log status code, as gRPC doesn't expose why -// (de)serialization fails -inline Status FailSerialization(Status status) { - if (!status.ok()) { - ARROW_LOG(WARNING) << "Error deserializing Flight message: " - << status.error_message(); - } - return status; -} - -inline arrow::Status FailSerialization(arrow::Status status) { - if (!status.ok()) { - ARROW_LOG(WARNING) << "Error deserializing Flight message: " << status.ToString(); - } - return status; -} - -// Write FlightData to a grpc::ByteBuffer without extra copying -Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, bool* own_buffer); - -// Read internal::FlightData from grpc::ByteBuffer containing FlightData -// protobuf without copying -Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out); - -} // namespace grpc diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 0b95e53..fe9c1bb 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -73,7 +73,7 @@ class FlightMessageReaderImpl : public FlightMessageReader { return Status::OK(); } - FlightData data; + internal::FlightData data; // Pretend to be pb::FlightData and intercept in SerializationTraits if (reader_->Read(reinterpret_cast<pb::FlightData*>(&data))) { std::unique_ptr<ipc::Message> message; @@ -209,7 +209,7 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status DoPut(ServerContext* context, grpc::ServerReader<pb::FlightData>* reader, pb::PutResult* response) { // Get metadata - FlightData data; + internal::FlightData data; if (reader->Read(reinterpret_cast<pb::FlightData*>(&data))) { // Message only lives as long as data std::unique_ptr<ipc::Message> message;