This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 2e61bcf  ARROW-4587: [C++] Fix segfaults around DoPut implementation
2e61bcf is described below

commit 2e61bcf0f67bbe051a2d252c8bd4f2e8df336ea3
Author: David Li <[email protected]>
AuthorDate: Tue Feb 19 16:32:18 2019 +0100

    ARROW-4587: [C++] Fix segfaults around DoPut implementation
    
    Also fixes exposing `arrow::ipc::DictionaryMemo` as `arrow::DictionaryMemo`.
    
    Author: David Li <[email protected]>
    
    Closes #3660 from lihalite/fix-flight-c++ and squashes the following commits:
    
    e670cb11c <David Li> Fix segfaults around DoPut implementation
---
 cpp/src/arrow/array.h                          |  2 +-
 cpp/src/arrow/array/builder_binary.h           |  2 +-
 cpp/src/arrow/flight/CMakeLists.txt            | 15 ++++---
 cpp/src/arrow/flight/client.cc                 | 40 +++++++++---------
 cpp/src/arrow/flight/client.h                  |  5 +--
 cpp/src/arrow/flight/customize_protobuf.h      | 19 +++------
 cpp/src/arrow/flight/flight-benchmark.cc       |  2 +-
 cpp/src/arrow/flight/internal.cc               |  6 ++-
 cpp/src/arrow/flight/internal.h                | 19 ++++++---
 cpp/src/arrow/flight/perf-server.cc            |  9 ++--
 cpp/src/arrow/flight/protocol-internal.h       |  6 +--
 cpp/src/arrow/flight/serialization-internal.cc | 58 +++++++++++++++++++++-----
 cpp/src/arrow/flight/serialization-internal.h  | 19 ++++-----
 cpp/src/arrow/flight/server.cc                 | 36 +++++++++-------
 cpp/src/arrow/flight/server.h                  | 25 ++---------
 cpp/src/arrow/flight/types.cc                  |  2 -
 cpp/src/arrow/flight/types.h                   |  8 ++++
 cpp/src/arrow/ipc/writer.h                     |  8 ++--
 18 files changed, 158 insertions(+), 123 deletions(-)

diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h
index 674bf7b..f8d451a 100644
--- a/cpp/src/arrow/array.h
+++ b/cpp/src/arrow/array.h
@@ -32,7 +32,7 @@
 #include "arrow/util/bit-util.h"
 #include "arrow/util/checked_cast.h"
 #include "arrow/util/macros.h"
-#include "arrow/util/string_view.h"
+#include "arrow/util/string_view.h"  // IWYU pragma: export
 #include "arrow/util/visibility.h"
 
 namespace arrow {
diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h
index abd8387..67d579d 100644
--- a/cpp/src/arrow/array/builder_binary.h
+++ b/cpp/src/arrow/array/builder_binary.h
@@ -29,7 +29,7 @@
 #include "arrow/status.h"
 #include "arrow/type_traits.h"
 #include "arrow/util/macros.h"
-#include "arrow/util/string_view.h"
+#include "arrow/util/string_view.h"  // IWYU pragma: export
 
 namespace arrow {
 
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index 46329e4..f02bc21 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -127,6 +127,9 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS)
   if(ARROW_BUILD_TESTS)
     add_dependencies(arrow-flight-test flight-test-server)
   endif()
+
+  add_dependencies(arrow_flight flight-test-server 
flight-test-integration-client
+                   flight-test-integration-server)
 endif()
 
 if(ARROW_BUILD_BENCHMARKS)
@@ -139,21 +142,23 @@ if(ARROW_BUILD_BENCHMARKS)
                              "--cpp_out=${CMAKE_CURRENT_BINARY_DIR}" 
"perf.proto"
                      DEPENDS ${PROTO_DEPENDS})
 
-  add_executable(flight-perf-server perf-server.cc perf.pb.cc)
-  target_link_libraries(flight-perf-server
+  add_executable(arrow-flight-perf-server perf-server.cc perf.pb.cc)
+  target_link_libraries(arrow-flight-perf-server
                         arrow_flight_shared
                         arrow_flight_testing_shared
                         ${ARROW_FLIGHT_TEST_LINK_LIBS}
                         ${GFLAGS_LIBRARY}
                         ${GTEST_LIBRARY})
 
-  add_executable(flight-benchmark flight-benchmark.cc perf.pb.cc)
-  target_link_libraries(flight-benchmark
+  add_executable(arrow-flight-benchmark flight-benchmark.cc perf.pb.cc)
+  target_link_libraries(arrow-flight-benchmark
                         arrow_flight_static
                         arrow_testing_static
                         ${ARROW_FLIGHT_TEST_LINK_LIBS}
                         ${GFLAGS_LIBRARY}
                         ${GTEST_LIBRARY})
 
-  add_dependencies(flight-benchmark flight-perf-server)
+  add_dependencies(arrow-flight-benchmark arrow-flight-perf-server)
+
+  add_dependencies(arrow_flight arrow-flight-benchmark)
 endif(ARROW_BUILD_BENCHMARKS)
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 8520777..9925c25 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -16,14 +16,14 @@
 // under the License.
 
 #include "arrow/flight/client.h"
-#include "arrow/flight/protocol-internal.h"
+#include "arrow/flight/protocol-internal.h"  // IWYU pragma: keep
 
 #include <memory>
 #include <sstream>
 #include <string>
 #include <utility>
 
-#include "grpcpp/grpcpp.h"
+#include <grpcpp/grpcpp.h>
 
 #include "arrow/ipc/dictionary.h"
 #include "arrow/ipc/metadata-internal.h"
@@ -36,12 +36,14 @@
 
 #include "arrow/flight/internal.h"
 #include "arrow/flight/serialization-internal.h"
-
-using arrow::ipc::internal::IpcPayload;
+#include "arrow/flight/types.h"
 
 namespace pb = arrow::flight::protocol;
 
 namespace arrow {
+
+class MemoryPool;
+
 namespace flight {
 
 struct ClientRpc {
@@ -106,8 +108,6 @@ class FlightStreamReader : public RecordBatchReader {
   std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream_;
 };
 
-class FlightClient;
-
 /// \brief A RecordBatchWriter implementation that writes to a Flight
 /// DoPut stream.
 class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter {
@@ -119,8 +119,9 @@ class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter {
       : rpc_(std::move(rpc)), descriptor_(descriptor), schema_(schema), 
pool_(pool) {}
 
   Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) 
override {
-    IpcPayload payload;
-    RETURN_NOT_OK(ipc::internal::GetRecordBatchPayload(batch, pool_, 
&payload));
+    FlightPayload payload;
+    RETURN_NOT_OK(
+        ipc::internal::GetRecordBatchPayload(batch, pool_, 
&payload.ipc_message));
 
     if (!writer_->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
                         grpc::WriteOptions())) {
@@ -296,19 +297,18 @@ class FlightClient::FlightClientImpl {
         stub_->DoPut(&out->rpc_->context, &out->response));
 
     // First write the descriptor and schema to the stream.
-    pb::FlightData descriptor_message;
-    RETURN_NOT_OK(
-        internal::ToProto(descriptor, 
descriptor_message.mutable_flight_descriptor()));
-
-    std::shared_ptr<Buffer> header_buf;
-    RETURN_NOT_OK(Buffer::FromString("", &header_buf));
+    FlightPayload payload;
     ipc::DictionaryMemo dictionary_memo;
-    RETURN_NOT_OK(ipc::SerializeSchema(*schema, out->pool_, &header_buf));
-    RETURN_NOT_OK(
-        ipc::internal::WriteSchemaMessage(*schema, &dictionary_memo, 
&header_buf));
-    descriptor_message.set_data_header(header_buf->ToString());
-
-    if (!write_stream->Write(descriptor_message, grpc::WriteOptions())) {
+    RETURN_NOT_OK(ipc::internal::GetSchemaPayload(*schema, out->pool_, 
&dictionary_memo,
+                                                  &payload.ipc_message));
+    pb::FlightDescriptor pb_descr;
+    RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descr));
+    std::string str_descr;
+    pb_descr.SerializeToString(&str_descr);
+    RETURN_NOT_OK(Buffer::FromString(str_descr, &payload.descriptor));
+
+    if (!write_stream->Write(*reinterpret_cast<const 
pb::FlightData*>(&payload),
+                             grpc::WriteOptions())) {
       std::stringstream ss;
       ss << "Could not write descriptor and schema to stream: "
          << rpc->context.debug_error_string();
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 334158d..e88b73f 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -28,18 +28,17 @@
 #include "arrow/status.h"
 #include "arrow/util/visibility.h"
 
-#include "arrow/flight/types.h"
+#include "arrow/flight/types.h"  // IWYU pragma: keep
 
 namespace arrow {
 
+class MemoryPool;
 class RecordBatch;
 class RecordBatchReader;
 class Schema;
 
 namespace flight {
 
-class FlightPutWriter;
-
 /// \brief Client class for Arrow Flight RPC services (gRPC-based).
 /// API experimental for now
 class ARROW_EXPORT FlightClient {
diff --git a/cpp/src/arrow/flight/customize_protobuf.h b/cpp/src/arrow/flight/customize_protobuf.h
index fd2e086..1f69253 100644
--- a/cpp/src/arrow/flight/customize_protobuf.h
+++ b/cpp/src/arrow/flight/customize_protobuf.h
@@ -20,7 +20,7 @@
 #include <limits>
 #include <memory>
 
-#include "grpcpp/impl/codegen/config_protobuf.h"
+#include <grpcpp/impl/codegen/config_protobuf.h>
 
 // It is necessary to undefined this macro so that the protobuf
 // SerializationTraits specialization is not declared in proto_utils.h. We've
@@ -29,20 +29,13 @@
 // for our faster serialization-deserialization path
 #undef GRPC_OPEN_SOURCE_PROTO
 
-#include "grpcpp/impl/codegen/proto_utils.h"
+#include <grpcpp/impl/codegen/proto_utils.h>
 
 namespace arrow {
-namespace ipc {
-namespace internal {
-
-struct IpcPayload;
-
-}  // namespace internal
-}  // namespace ipc
-
 namespace flight {
 
 struct FlightData;
+struct FlightPayload;
 
 namespace protocol {
 
@@ -55,12 +48,12 @@ class FlightData;
 namespace grpc {
 
 using arrow::flight::FlightData;
-using arrow::ipc::internal::IpcPayload;
+using arrow::flight::FlightPayload;
 
 class ByteBuffer;
 class Status;
 
-Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool* 
own_buffer);
+Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, bool* 
own_buffer);
 Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out);
 
 // This class provides a protobuf serializer. It translates between protobuf
@@ -88,7 +81,7 @@ class SerializationTraits<T, typename std::enable_if<std::is_same<
  public:
   static Status Serialize(const grpc::protobuf::Message& msg, ByteBuffer* bb,
                           bool* own_buffer) {
-    return FlightDataSerialize(*reinterpret_cast<const IpcPayload*>(&msg), bb,
+    return FlightDataSerialize(*reinterpret_cast<const FlightPayload*>(&msg), 
bb,
                                own_buffer);
   }
 
diff --git a/cpp/src/arrow/flight/flight-benchmark.cc b/cpp/src/arrow/flight/flight-benchmark.cc
index adb775c..0f3eb6b 100644
--- a/cpp/src/arrow/flight/flight-benchmark.cc
+++ b/cpp/src/arrow/flight/flight-benchmark.cc
@@ -183,7 +183,7 @@ int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
   const int port = 31337;
-  arrow::flight::TestServer server("flight-perf-server", port);
+  arrow::flight::TestServer server("arrow-flight-perf-server", port);
   server.Start();
 
   arrow::Status s = arrow::flight::RunPerformanceTest(port);
diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc
index 21268e8..1085e11 100644
--- a/cpp/src/arrow/flight/internal.cc
+++ b/cpp/src/arrow/flight/internal.cc
@@ -16,18 +16,20 @@
 // under the License.
 
 #include "arrow/flight/internal.h"
+#include "arrow/flight/protocol-internal.h"
 
-#include "arrow/flight/customize_protobuf.h"
-
+#include <cstddef>
 #include <memory>
 #include <string>
 #include <utility>
 
 #include <grpcpp/grpcpp.h>
 
+#include "arrow/buffer.h"
 #include "arrow/io/memory.h"
 #include "arrow/ipc/reader.h"
 #include "arrow/ipc/writer.h"
+#include "arrow/memory_pool.h"
 #include "arrow/status.h"
 #include "arrow/util/logging.h"
 
diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h
index 15c3d71..05adb6c 100644
--- a/cpp/src/arrow/flight/internal.h
+++ b/cpp/src/arrow/flight/internal.h
@@ -20,14 +20,15 @@
 #include <memory>
 #include <string>
 
-#include <grpcpp/grpcpp.h>
-
-#include "arrow/buffer.h"
-#include "arrow/ipc/writer.h"
+#include "arrow/flight/protocol-internal.h"  // IWYU pragma: keep
+#include "arrow/flight/types.h"
 #include "arrow/util/macros.h"
 
-#include "arrow/flight/protocol-internal.h"
-#include "arrow/flight/types.h"
+namespace grpc {
+
+class Status;
+
+}  // namespace grpc
 
 namespace arrow {
 
@@ -36,6 +37,12 @@ class Status;
 
 namespace pb = arrow::flight::protocol;
 
+namespace ipc {
+
+class Message;
+
+}  // namespace ipc
+
 namespace flight {
 
 #define GRPC_RETURN_NOT_OK(s)                             \
diff --git a/cpp/src/arrow/flight/perf-server.cc b/cpp/src/arrow/flight/perf-server.cc
index b470283..bee7060 100644
--- a/cpp/src/arrow/flight/perf-server.cc
+++ b/cpp/src/arrow/flight/perf-server.cc
@@ -41,8 +41,6 @@ DEFINE_int32(port, 31337, "Server port to listen on");
 namespace perf = arrow::flight::perf;
 namespace proto = arrow::flight::protocol;
 
-using IpcPayload = arrow::ipc::internal::IpcPayload;
-
 namespace arrow {
 namespace flight {
 
@@ -73,10 +71,10 @@ class PerfDataStream : public FlightDataStream {
 
   std::shared_ptr<Schema> schema() override { return schema_; }
 
-  Status Next(IpcPayload* payload) override {
+  Status Next(FlightPayload* payload) override {
     if (records_sent_ >= total_records_) {
       // Signal that iteration is over
-      payload->metadata = nullptr;
+      payload->ipc_message.metadata = nullptr;
       return Status::OK();
     }
 
@@ -98,7 +96,8 @@ class PerfDataStream : public FlightDataStream {
     } else {
       records_sent_ += batch_length_;
     }
-    return ipc::internal::GetRecordBatchPayload(*batch, default_memory_pool(), 
payload);
+    return ipc::internal::GetRecordBatchPayload(*batch, default_memory_pool(),
+                                                &payload->ipc_message);
   }
 
  private:
diff --git a/cpp/src/arrow/flight/protocol-internal.h b/cpp/src/arrow/flight/protocol-internal.h
index d3ba77f..2e8dd32 100644
--- a/cpp/src/arrow/flight/protocol-internal.h
+++ b/cpp/src/arrow/flight/protocol-internal.h
@@ -17,7 +17,7 @@
 #pragma once
 
 // Need to include this first to get our gRPC customizations
-#include "arrow/flight/customize_protobuf.h"
+#include "arrow/flight/customize_protobuf.h"  // IWYU pragma: export
 
-#include "arrow/flight/Flight.grpc.pb.h"
-#include "arrow/flight/Flight.pb.h"
+#include "arrow/flight/Flight.grpc.pb.h"  // IWYU pragma: export
+#include "arrow/flight/Flight.pb.h"       // IWYU pragma: export
diff --git a/cpp/src/arrow/flight/serialization-internal.cc b/cpp/src/arrow/flight/serialization-internal.cc
index 67b2155..0c031e0 100644
--- a/cpp/src/arrow/flight/serialization-internal.cc
+++ b/cpp/src/arrow/flight/serialization-internal.cc
@@ -17,6 +17,12 @@
 
 #include "arrow/flight/serialization-internal.h"
 
+#include <string>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/server.h"
+#include "arrow/ipc/writer.h"
+
 namespace arrow {
 namespace flight {
 namespace internal {
@@ -105,17 +111,27 @@ using google::protobuf::io::ArrayOutputStream;
 using google::protobuf::io::CodedInputStream;
 using google::protobuf::io::CodedOutputStream;
 
-Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool* 
own_buffer) {
+Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, bool* 
own_buffer) {
   size_t total_size = 0;
 
-  DCHECK_LT(msg.metadata->size(), kInt32Max);
-  const int32_t metadata_size = static_cast<int32_t>(msg.metadata->size());
+  // Write the descriptor if present
+  int32_t descriptor_size = 0;
+  if (msg.descriptor != nullptr) {
+    DCHECK_LT(msg.descriptor->size(), kInt32Max);
+    descriptor_size = static_cast<int32_t>(msg.descriptor->size());
+    total_size += 1 + WireFormatLite::LengthDelimitedSize(descriptor_size);
+  }
+
+  const arrow::ipc::internal::IpcPayload& ipc_msg = msg.ipc_message;
+
+  DCHECK_LT(ipc_msg.metadata->size(), kInt32Max);
+  const int32_t metadata_size = static_cast<int32_t>(ipc_msg.metadata->size());
 
   // 1 byte for metadata tag
   total_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size);
 
   int64_t body_size = 0;
-  for (const auto& buffer : msg.body_buffers) {
+  for (const auto& buffer : ipc_msg.body_buffers) {
     // Buffer may be null when the row length is zero, or when all
     // entries are invalid.
     if (!buffer) continue;
@@ -130,7 +146,7 @@ Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool* own_buf
 
   // 2 bytes for body tag
   // Only written when there are body buffers
-  if (msg.body_length > 0) {
+  if (ipc_msg.body_length > 0) {
     total_size += 2 + 
WireFormatLite::LengthDelimitedSize(static_cast<size_t>(body_size));
   }
 
@@ -150,15 +166,24 @@ Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool* own_buf
                            static_cast<int>(slice.size()));
   CodedOutputStream pb_stream(&writer);
 
+  // Write descriptor
+  if (msg.descriptor != nullptr) {
+    WireFormatLite::WriteTag(pb::FlightData::kFlightDescriptorFieldNumber,
+                             WireFormatLite::WIRETYPE_LENGTH_DELIMITED, 
&pb_stream);
+    pb_stream.WriteVarint32(descriptor_size);
+    pb_stream.WriteRawMaybeAliased(msg.descriptor->data(),
+                                   static_cast<int>(msg.descriptor->size()));
+  }
+
   // Write header
   WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber,
                            WireFormatLite::WIRETYPE_LENGTH_DELIMITED, 
&pb_stream);
   pb_stream.WriteVarint32(metadata_size);
-  pb_stream.WriteRawMaybeAliased(msg.metadata->data(),
-                                 static_cast<int>(msg.metadata->size()));
+  pb_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(),
+                                 static_cast<int>(ipc_msg.metadata->size()));
 
   // Don't write tag if there are no body buffers
-  if (msg.body_length > 0) {
+  if (ipc_msg.body_length > 0) {
     // Write body
     WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
                              WireFormatLite::WIRETYPE_LENGTH_DELIMITED, 
&pb_stream);
@@ -166,7 +191,7 @@ Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool* own_buf
 
     constexpr uint8_t kPaddingBytes[8] = {0};
 
-    for (const auto& buffer : msg.body_buffers) {
+    for (const auto& buffer : ipc_msg.body_buffers) {
       // Buffer may be null when the row length is zero, or when all
       // entries are invalid.
       if (!buffer) continue;
@@ -213,9 +238,22 @@ Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) {
     switch (field_number) {
       case pb::FlightData::kFlightDescriptorFieldNumber: {
         pb::FlightDescriptor pb_descriptor;
-        if (!pb_descriptor.ParseFromCodedStream(&pb_stream)) {
+        uint32_t length;
+        if (!pb_stream.ReadVarint32(&length)) {
+          return Status(StatusCode::INTERNAL,
+                        "Unable to parse length of FlightDescriptor");
+        }
+        // Can't use ParseFromCodedStream as this reads the entire
+        // rest of the stream into the descriptor command field.
+        std::string buffer;
+        pb_stream.ReadString(&buffer, length);
+        if (!pb_descriptor.ParseFromString(buffer)) {
           return Status(StatusCode::INTERNAL, "Unable to parse 
FlightDescriptor");
         }
+        arrow::flight::FlightDescriptor descriptor;
+        GRPC_RETURN_NOT_OK(
+            arrow::flight::internal::FromProto(pb_descriptor, &descriptor));
+        out->descriptor.reset(new arrow::flight::FlightDescriptor(descriptor));
       } break;
       case pb::FlightData::kDataHeaderFieldNumber: {
         if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) {
diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h
index 19c8592..d8e7aad 100644
--- a/cpp/src/arrow/flight/serialization-internal.h
+++ b/cpp/src/arrow/flight/serialization-internal.h
@@ -21,18 +21,19 @@
 #pragma once
 
 // Enable gRPC customizations
-#include "arrow/flight/protocol-internal.h"
+#include "arrow/flight/protocol-internal.h"  // IWYU pragma: keep
 
+#include <cstdint>
 #include <limits>
 #include <memory>
 
-#include "google/protobuf/io/coded_stream.h"
-#include "google/protobuf/io/zero_copy_stream.h"
-#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
-#include "google/protobuf/wire_format_lite.h"
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
+#include <google/protobuf/wire_format_lite.h>
+#include <grpcpp/grpcpp.h>
+#include <grpcpp/impl/codegen/proto_utils.h>
 #include "grpc/byte_buffer_reader.h"
-#include "grpcpp/grpcpp.h"
-#include "grpcpp/impl/codegen/proto_utils.h"
 
 #include "arrow/ipc/writer.h"
 #include "arrow/record_batch.h"
@@ -44,8 +45,6 @@
 
 namespace pb = arrow::flight::protocol;
 
-using arrow::ipc::internal::IpcPayload;
-
 constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();
 
 namespace arrow {
@@ -103,7 +102,7 @@ inline arrow::Status FailSerialization(arrow::Status status) {
 }
 
 // Write FlightData to a grpc::ByteBuffer without extra copying
-Status FlightDataSerialize(const IpcPayload& msg, ByteBuffer* out, bool* 
own_buffer);
+Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, bool* 
own_buffer);
 
 // Read internal::FlightData from grpc::ByteBuffer containing FlightData
 // protobuf without copying
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index 2fef93d..0b95e53 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -21,11 +21,14 @@
 #include <cstdint>
 #include <memory>
 #include <string>
+#include <utility>
 
-#include "grpcpp/grpcpp.h"
+#include <grpcpp/grpcpp.h>
 
+#include "arrow/ipc/dictionary.h"
 #include "arrow/ipc/reader.h"
 #include "arrow/ipc/writer.h"
+#include "arrow/memory_pool.h"
 #include "arrow/record_batch.h"
 #include "arrow/status.h"
 #include "arrow/util/logging.h"
@@ -37,8 +40,6 @@
 using FlightService = arrow::flight::protocol::FlightService;
 using ServerContext = grpc::ServerContext;
 
-using arrow::ipc::internal::IpcPayload;
-
 template <typename T>
 using ServerWriter = grpc::ServerWriter<T>;
 
@@ -180,20 +181,21 @@ class FlightServiceImpl : public FlightService::Service {
     GRPC_RETURN_NOT_OK(server_->DoGet(ticket, &data_stream));
 
     // Write the schema as the first message in the stream
-    IpcPayload schema_payload;
+    FlightPayload schema_payload;
     MemoryPool* pool = default_memory_pool();
     ipc::DictionaryMemo dictionary_memo;
     GRPC_RETURN_NOT_OK(ipc::internal::GetSchemaPayload(
-        *data_stream->schema(), pool, &dictionary_memo, &schema_payload));
+        *data_stream->schema(), pool, &dictionary_memo, 
&schema_payload.ipc_message));
 
-    // Pretend to be pb::FlightData, we cast back to IpcPayload in 
SerializationTraits
+    // Pretend to be pb::FlightData, we cast back to FlightPayload in
+    // SerializationTraits
     writer->Write(*reinterpret_cast<const pb::FlightData*>(&schema_payload),
                   grpc::WriteOptions());
 
     while (true) {
-      IpcPayload payload;
+      FlightPayload payload;
       GRPC_RETURN_NOT_OK(data_stream->Next(&payload));
-      if (payload.metadata == nullptr ||
+      if (payload.ipc_message.metadata == nullptr ||
           !writer->Write(*reinterpret_cast<const pb::FlightData*>(&payload),
                          grpc::WriteOptions())) {
         // No more messages to write, or connection terminated for some other
@@ -207,22 +209,24 @@ class FlightServiceImpl : public FlightService::Service {
   grpc::Status DoPut(ServerContext* context, 
grpc::ServerReader<pb::FlightData>* reader,
                      pb::PutResult* response) {
     // Get metadata
-    pb::FlightData data;
-    if (reader->Read(&data)) {
-      FlightDescriptor descriptor;
+    FlightData data;
+    if (reader->Read(reinterpret_cast<pb::FlightData*>(&data))) {
       // Message only lives as long as data
       std::unique_ptr<ipc::Message> message;
-      GRPC_RETURN_NOT_OK(internal::FromProto(data, &descriptor, &message));
+      GRPC_RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, 
&message));
 
       if (!message || message->type() != ipc::Message::Type::SCHEMA) {
         return internal::ToGrpcStatus(
             Status(StatusCode::Invalid, "DoPut must start with 
schema/descriptor"));
+      } else if (data.descriptor == nullptr) {
+        return internal::ToGrpcStatus(
+            Status(StatusCode::Invalid, "DoPut must start with non-null 
descriptor"));
       } else {
         std::shared_ptr<Schema> schema;
         GRPC_RETURN_NOT_OK(ipc::ReadSchema(*message, &schema));
 
         auto message_reader = std::unique_ptr<FlightMessageReader>(
-            new FlightMessageReaderImpl(descriptor, schema, reader));
+            new FlightMessageReaderImpl(*data.descriptor.get(), schema, 
reader));
         return 
internal::ToGrpcStatus(server_->DoPut(std::move(message_reader)));
       }
     } else {
@@ -333,16 +337,16 @@ RecordBatchStream::RecordBatchStream(const std::shared_ptr<RecordBatchReader>& r
 
 std::shared_ptr<Schema> RecordBatchStream::schema() { return 
reader_->schema(); }
 
-Status RecordBatchStream::Next(IpcPayload* payload) {
+Status RecordBatchStream::Next(FlightPayload* payload) {
   std::shared_ptr<RecordBatch> batch;
   RETURN_NOT_OK(reader_->ReadNext(&batch));
 
   if (!batch) {
     // Signal that iteration is over
-    payload->metadata = nullptr;
+    payload->ipc_message.metadata = nullptr;
     return Status::OK();
   } else {
-    return ipc::internal::GetRecordBatchPayload(*batch, pool_, payload);
+    return ipc::internal::GetRecordBatchPayload(*batch, pool_, 
&payload->ipc_message);
   }
 }
 
diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h
index b2e8b02..407c4bc 100644
--- a/cpp/src/arrow/flight/server.h
+++ b/cpp/src/arrow/flight/server.h
@@ -21,36 +21,19 @@
 #pragma once
 
 #include <memory>
-#include <string>
-#include <utility>
 #include <vector>
 
 #include "arrow/util/visibility.h"
 
-#include "arrow/flight/types.h"
-#include "arrow/ipc/dictionary.h"
+#include "arrow/flight/types.h"  // IWYU pragma: keep
 #include "arrow/record_batch.h"
 
 namespace arrow {
 
 class MemoryPool;
-class RecordBatchReader;
+class Schema;
 class Status;
 
-namespace ipc {
-namespace internal {
-
-struct IpcPayload;
-
-}  // namespace internal
-}  // namespace ipc
-
-namespace io {
-
-class OutputStream;
-
-}  // namespace io
-
 namespace flight {
 
 /// \brief Interface that produces a sequence of IPC payloads to be sent in
@@ -64,7 +47,7 @@ class ARROW_EXPORT FlightDataStream {
 
   // When the stream is completed, the last payload written will have null
   // metadata
-  virtual Status Next(ipc::internal::IpcPayload* payload) = 0;
+  virtual Status Next(FlightPayload* payload) = 0;
 };
 
 /// \brief A basic implementation of FlightDataStream that will provide
@@ -75,7 +58,7 @@ class ARROW_EXPORT RecordBatchStream : public FlightDataStream {
   explicit RecordBatchStream(const std::shared_ptr<RecordBatchReader>& reader);
 
   std::shared_ptr<Schema> schema() override;
-  Status Next(ipc::internal::IpcPayload* payload) override;
+  Status Next(FlightPayload* payload) override;
 
  private:
   MemoryPool* pool_;
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index aba93ad..fb8f8c6 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -18,8 +18,6 @@
 #include "arrow/flight/types.h"
 
 #include <memory>
-#include <sstream>
-#include <string>
 #include <utility>
 
 #include "arrow/io/memory.h"
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index 6db2655..ba0ab85 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -19,12 +19,14 @@
 
 #pragma once
 
+#include <cstddef>
 #include <cstdint>
 #include <memory>
 #include <string>
 #include <utility>
 #include <vector>
 
+#include "arrow/ipc/writer.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -111,6 +113,12 @@ struct FlightEndpoint {
   std::vector<Location> locations;
 };
 
+/// \brief Staging data structure for messages about to be put on the wire
+struct FlightPayload {
+  std::shared_ptr<Buffer> descriptor;
+  ipc::internal::IpcPayload ipc_message;
+};
+
 /// \brief The access coordinates for retireval of a dataset, returned by
 /// GetFlightInfo
 class FlightInfo {
diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h
index 5b099d5..50872e9 100644
--- a/cpp/src/arrow/ipc/writer.h
+++ b/cpp/src/arrow/ipc/writer.h
@@ -30,7 +30,6 @@
 namespace arrow {
 
 class Buffer;
-class DictionaryMemo;
 class MemoryPool;
 class RecordBatch;
 class Schema;
@@ -47,6 +46,8 @@ class OutputStream;
 
 namespace ipc {
 
+class DictionaryMemo;
+
 /// \class RecordBatchWriter
 /// \brief Abstract interface for writing a stream of record batches
 class ARROW_EXPORT RecordBatchWriter {
@@ -298,9 +299,8 @@ namespace internal {
 
 // These internal APIs may change without warning or deprecation
 
-// Intermediate data structure with metadata header plus zero or more buffers
-// for the message body. This data can either be written out directly as an
-// encapsulated IPC message or used with Flight RPCs
+// Intermediate data structure with metadata header, and zero or more buffers
+// for the message body.
 struct IpcPayload {
   Message::Type type;
   std::shared_ptr<Buffer> metadata;

Reply via email to