lidavidm commented on a change in pull request #12465:
URL: https://github.com/apache/arrow/pull/12465#discussion_r817152998



##########
File path: cpp/src/arrow/flight/client.cc
##########
@@ -596,198 +427,256 @@ class GrpcStreamReader : public FlightStreamReader {
     return ReadAll(table, stop_token_);
   }
   using FlightStreamReader::ReadAll;
-  void Cancel() override { rpc_->context.TryCancel(); }
+  void Cancel() override { stream_->TryCancel(); }
 
  private:
-  std::unique_lock<std::mutex> TakeGuard() {
-    return read_mutex_ ? std::unique_lock<std::mutex>(*read_mutex_)
-                       : std::unique_lock<std::mutex>();
-  }
-
   Status OverrideWithServerError(Status&& st) {
     if (st.ok()) {
       return std::move(st);
     }
     return stream_->Finish(std::move(st));
   }
 
-  friend class GrpcIpcMessageReader<Reader>;
-  std::shared_ptr<ClientRpc> rpc_;
-  std::shared_ptr<MemoryManager> memory_manager_;
-  // Guard reads with a lock to prevent Finish()/Close() from being
-  // called on the writer while the reader has a pending
-  // read. Nullable, as DoGet() doesn't need this.
-  std::shared_ptr<std::mutex> read_mutex_;
+  std::shared_ptr<internal::ClientDataStream> stream_;
   ipc::IpcReadOptions options_;
   StopToken stop_token_;
-  std::shared_ptr<FinishableStream<Reader, internal::FlightData>> stream_;
-  std::shared_ptr<internal::PeekableFlightDataReader<std::shared_ptr<Reader>>>
-      peekable_reader_;
+  std::shared_ptr<internal::PeekableFlightDataReader> peekable_reader_;
   std::shared_ptr<ipc::RecordBatchReader> batch_reader_;
   std::shared_ptr<Buffer> app_metadata_;
 };
 
-// The next two classes implement writing to a FlightData stream.
-// Similarly to the read side, we want to reuse the implementation of
-// RecordBatchWriter. As a result, these two classes are intertwined
-// in order to pass application metadata "through" RecordBatchWriter.
-// In order to get application-specific metadata to the
-// IpcPayloadWriter, DoPutPayloadWriter takes a pointer to
-// GrpcStreamWriter. GrpcStreamWriter updates a metadata field on
-// write; DoPutPayloadWriter reads that metadata field to determine
-// what to write.
-
-template <typename ProtoReadT, typename FlightReadT>
-class DoPutPayloadWriter;
-
-template <typename ProtoReadT, typename FlightReadT>
-class GrpcStreamWriter : public FlightStreamWriter {
+FlightMetadataReader::~FlightMetadataReader() = default;
+
+/// \brief The base of the ClientDataStream implementation for gRPC.
+template <typename Stream, typename ReadPayload>
+class FinishableDataStream : public internal::ClientDataStream {
  public:
-  ~GrpcStreamWriter() override = default;
+  FinishableDataStream(std::shared_ptr<ClientRpc> rpc, std::shared_ptr<Stream> 
stream,
+                       std::shared_ptr<MemoryManager> memory_manager)
+      : rpc_(std::move(rpc)),
+        stream_(std::move(stream)),
+        memory_manager_(memory_manager ? std::move(memory_manager)
+                                       : 
CPUDevice::Instance()->default_memory_manager()),
+        finished_(false) {}
 
-  using GrpcStream = grpc::ClientReaderWriter<pb::FlightData, ProtoReadT>;
+  Status Finish() override {
+    if (finished_) {
+      return server_status_;
+    }
 
-  explicit GrpcStreamWriter(
-      const FlightDescriptor& descriptor, std::shared_ptr<ClientRpc> rpc,
-      int64_t write_size_limit_bytes, const ipc::IpcWriteOptions& options,
-      std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> 
writer)
-      : app_metadata_(nullptr),
-        batch_writer_(nullptr),
-        writer_(std::move(writer)),
-        rpc_(std::move(rpc)),
-        write_size_limit_bytes_(write_size_limit_bytes),
-        options_(options),
-        descriptor_(descriptor),
-        writer_closed_(false) {}
+    // Drain the read side, as otherwise gRPC Finish() will hang. We
+    // only call Finish() when the client closes the writer or the
+    // reader finishes, so it's OK to assume the client no longer
+    // wants to read and drain the read side. (If the client wants to
+    // indicate that it is done writing, but not done reading, it
+    // should use DoneWriting().)
+    ReadPayload message;
+    while (internal::ReadPayload(stream_.get(), &message)) {
+      // Drain the read side to avoid gRPC hanging in Finish()
+    }
 
-  static Status Open(
-      const FlightDescriptor& descriptor, std::shared_ptr<Schema> schema,
-      const ipc::IpcWriteOptions& options, std::shared_ptr<ClientRpc> rpc,
-      int64_t write_size_limit_bytes,
-      std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> 
writer,
-      std::unique_ptr<FlightStreamWriter>* out);
+    server_status_ = internal::FromGrpcStatus(stream_->Finish(), 
&rpc_->context);
+    if (!server_status_.ok()) {
+      server_status_ = Status::FromDetailAndArgs(
+          server_status_.code(), server_status_.detail(), 
server_status_.message(),
+          ". gRPC client debug context: ", rpc_->context.debug_error_string());
+    }
+    finished_ = true;
 
-  Status CheckStarted() {
-    if (!batch_writer_) {
-      return Status::Invalid("Writer not initialized. Call Begin() with a 
schema.");
+    return server_status_;
+  }
+  void TryCancel() override { rpc_->context.TryCancel(); }
+
+  std::shared_ptr<ClientRpc> rpc_;
+  std::shared_ptr<Stream> stream_;
+  std::shared_ptr<MemoryManager> memory_manager_;
+  bool finished_;
+  Status server_status_;
+};
+
+/// \brief A ClientDataStream implementation for gRPC that uses
+///   mutexes to serialize Finish() calls and to reject Finish() while
+///   a read is pending, and drains the read side on finish.
+template <typename Stream, typename ReadPayload>
+class WritableDataStream : public FinishableDataStream<Stream, ReadPayload> {
+ public:
+  using Base = FinishableDataStream<Stream, ReadPayload>;
+  WritableDataStream(std::shared_ptr<ClientRpc> rpc, std::shared_ptr<Stream> 
stream,
+                     std::shared_ptr<MemoryManager> memory_manager)
+      : Base(std::move(rpc), std::move(stream), std::move(memory_manager)),
+        read_mutex_(),
+        finish_mutex_(),
+        done_writing_(false) {}
+
+  Status WritesDone() override {
+    // This is only used by the writer side of a stream, so it need
+    // not be protected with a lock.
+    if (done_writing_) {
+      return Status::OK();
+    }
+    done_writing_ = true;
+    if (!stream_->WritesDone()) {
+      // Error happened, try to close the stream to get more detailed info
+      return internal::ClientDataStream::Finish(MakeFlightError(
+          FlightStatusCode::Internal, "Could not flush pending record 
batches"));
     }
     return Status::OK();
   }
 
-  Status Begin(const std::shared_ptr<Schema>& schema,
-               const ipc::IpcWriteOptions& options) override {
-    if (batch_writer_) {
-      return Status::Invalid("This writer has already been started.");
+  Status Finish() override {
+    // This may be used concurrently by reader/writer side of a
+    // stream, so it needs to be protected.
+    std::lock_guard<std::mutex> guard(finish_mutex_);
+
+    // Now that we're shared between a reader and writer, we need to
+    // protect ourselves from being called while there's an
+    // outstanding read.
+    std::unique_lock<std::mutex> read_guard(read_mutex_, std::try_to_lock);
+    if (!read_guard.owns_lock()) {
+      return MakeFlightError(FlightStatusCode::Internal,
+                             "Cannot close stream with pending read 
operation.");
     }
-    std::unique_ptr<ipc::internal::IpcPayloadWriter> payload_writer(
-        new DoPutPayloadWriter<ProtoReadT, FlightReadT>(
-            descriptor_, std::move(rpc_), write_size_limit_bytes_, writer_, 
this));
-    // XXX: this does not actually write the message to the stream.
-    // See Close().
-    ARROW_ASSIGN_OR_RAISE(batch_writer_, ipc::internal::OpenRecordBatchWriter(
-                                             std::move(payload_writer), 
schema, options));
-    return Status::OK();
+
+    // Try to flush pending writes. Don't use our WritesDone() to
+    // avoid recursion.
+    bool finished_writes = done_writing_ || stream_->WritesDone();
+    done_writing_ = true;
+
+    Status st = Base::Finish();
+    if (!finished_writes) {
+      return Status::FromDetailAndArgs(
+          st.code(), st.detail(), st.message(),
+          ". Additionally, could not finish writing record batches before 
closing");
+    }
+    return st;
   }
 
-  Status Begin(const std::shared_ptr<Schema>& schema) override {
-    return Begin(schema, options_);
+  using Base::stream_;
+  std::mutex read_mutex_;
+  std::mutex finish_mutex_;
+  bool done_writing_;
+};
+
+class GrpcClientGetStream
+    : public FinishableDataStream<grpc::ClientReader<pb::FlightData>,
+                                  internal::FlightData> {
+ public:
+  using FinishableDataStream::FinishableDataStream;
+
+  bool ReadData(internal::FlightData* data) override {
+    bool success = internal::ReadPayload(stream_.get(), data);
+    if (ARROW_PREDICT_FALSE(!success)) return false;
+    if (data->body &&
+        
ARROW_PREDICT_FALSE(!data->body->device()->Equals(*memory_manager_->device()))) 
{

Review comment:
       Ah, fair point. I removed the extra check.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to