pitrou commented on a change in pull request #12465: URL: https://github.com/apache/arrow/pull/12465#discussion_r818848832
########## File path: cpp/src/arrow/flight/transport/grpc/grpc_server.cc ########## @@ -0,0 +1,634 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// gRPC transport implementation for Arrow Flight + +#include "arrow/flight/transport/grpc/grpc_server.h" + +#include <mutex> +#include <sstream> +#include <string> +#include <unordered_map> +#include <utility> + +#include "arrow/util/config.h" +#ifdef GRPCPP_PP_INCLUDE +#include <grpcpp/grpcpp.h> +#else +#include <grpc++/grpc++.h> +#endif + +#include "arrow/buffer.h" +#include "arrow/flight/internal.h" +#include "arrow/flight/serialization_internal.h" +#include "arrow/flight/server.h" +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/transport.h" +#include "arrow/flight/transport_server.h" +#include "arrow/flight/types.h" +#include "arrow/util/logging.h" +#include "arrow/util/uri.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace grpc { + +namespace pb = arrow::flight::protocol; +using FlightService = pb::FlightService; +using ServerContext = ::grpc::ServerContext; +template <typename T> +using ServerWriter = ::grpc::ServerWriter<T>; + +// Macro that runs interceptors before returning the given status +#define RETURN_WITH_MIDDLEWARE(CONTEXT, STATUS) \ + do { \ + const auto& __s = (STATUS); \ + return CONTEXT.FinishRequest(__s); \ + } while (false) +#define CHECK_ARG_NOT_NULL(CONTEXT, VAL, MESSAGE) \ + if (VAL == nullptr) { \ + RETURN_WITH_MIDDLEWARE( \ + CONTEXT, ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT, MESSAGE)); \ + } +// Same as RETURN_NOT_OK, but accepts either Arrow or gRPC status, and +// will run interceptors +#define SERVICE_RETURN_NOT_OK(CONTEXT, expr) \ + do { \ + const auto& _s = (expr); \ + if (ARROW_PREDICT_FALSE(!_s.ok())) { \ + return CONTEXT.FinishRequest(_s); \ + } \ + } while (false) + +namespace { +class GrpcServerAuthReader : public ServerAuthReader { + public: + explicit GrpcServerAuthReader( + ::grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream) + : stream_(stream) {} + + Status Read(std::string* token) override { + pb::HandshakeRequest request; + if (stream_->Read(&request)) { + *token = std::move(*request.mutable_payload()); + return Status::OK(); + } + return Status::IOError("Stream is closed."); + } + + private: + ::grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream_; +}; + +class GrpcServerAuthSender : public ServerAuthSender { + public: + explicit GrpcServerAuthSender( + ::grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream) + : stream_(stream) {} + + Status Write(const std::string& token) override { + pb::HandshakeResponse response; + response.set_payload(token); + if (stream_->Write(response)) { + return Status::OK(); + } + return Status::IOError("Stream was closed."); + } + + private: + ::grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream_; +}; + +class GrpcServerCallContext : public ServerCallContext { + explicit GrpcServerCallContext(::grpc::ServerContext* context) + : context_(context), peer_(context_->peer()) {} + + const std::string& peer_identity() const override { return peer_identity_; } + const std::string& peer() const override { return peer_; } + bool is_cancelled() const override { return context_->IsCancelled(); } + + // Helper method that runs interceptors given the result of an RPC, + // then returns the final gRPC status to send to the client + ::grpc::Status FinishRequest(const ::grpc::Status& status) { + // Don't double-convert status - return the original one here + FinishRequest(internal::FromGrpcStatus(status)); + return status; + } + + ::grpc::Status FinishRequest(const arrow::Status& status) { + for (const auto& instance : middleware_) { + instance->CallCompleted(status); + } + + // Set custom headers to map the exact Arrow status for clients + // who want it. + return internal::ToGrpcStatus(status, context_); + } + + ServerMiddleware* GetMiddleware(const std::string& key) const override { + const auto& instance = middleware_map_.find(key); + if (instance == middleware_map_.end()) { + return nullptr; + } + return instance->second.get(); + } + + private: + friend class GrpcServiceHandler; + ServerContext* context_; + std::string peer_; + std::string peer_identity_; + std::vector<std::shared_ptr<ServerMiddleware>> middleware_; + std::unordered_map<std::string, std::shared_ptr<ServerMiddleware>> middleware_map_; +}; + +class GrpcAddServerHeaders : public AddCallHeaders { + public: + explicit GrpcAddServerHeaders(::grpc::ServerContext* context) : context_(context) {} + ~GrpcAddServerHeaders() override = default; + + void AddHeader(const std::string& key, const std::string& value) override { + context_->AddInitialMetadata(key, value); + } + + private: + ::grpc::ServerContext* context_; +}; + +// A ServerDataStream for streaming data to the client. +class GetDataStream : public internal::ServerDataStream { + public: + explicit GetDataStream(ServerWriter<pb::FlightData>* writer) : writer_(writer) {} + + arrow::Result<bool> WriteData(const FlightPayload& payload) override { + return internal::WritePayload(payload, writer_); + } + + private: + ServerWriter<pb::FlightData>* writer_; +}; + +// A ServerDataStream for reading data from the client. +class PutDataStream final : public internal::ServerDataStream { + public: + explicit PutDataStream( + ::grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* stream) + : stream_(stream) {} + + bool ReadData(internal::FlightData* data) override { + return internal::ReadPayload(&*stream_, data); + } + Status WritePutMetadata(const Buffer& metadata) override { + pb::PutResult message{}; + message.set_app_metadata(metadata.data(), metadata.size()); + if (stream_->Write(message)) { + return Status::OK(); + } + return Status::IOError("Unknown error writing metadata."); + } + + private: + ::grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* stream_; +}; + +// A ServerDataStream for a bidirectional data exchange. +class ExchangeDataStream final : public internal::ServerDataStream { + public: + explicit ExchangeDataStream( + ::grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* stream) + : stream_(stream) {} + + bool ReadData(internal::FlightData* data) override { + return internal::ReadPayload(&*stream_, data); + } + arrow::Result<bool> WriteData(const FlightPayload& payload) override { + return internal::WritePayload(payload, stream_); + } + + private: + ::grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* stream_; +}; + +// The gRPC service implementation, which forwards calls to the Flight +// service and bridges between the Flight transport API and gRPC. +class GrpcServiceHandler final : public FlightService::Service { + public: + GrpcServiceHandler( + std::shared_ptr<ServerAuthHandler> auth_handler, + std::vector<std::pair<std::string, std::shared_ptr<ServerMiddlewareFactory>>> + middleware, + internal::ServerTransport* impl) + : auth_handler_(auth_handler), middleware_(middleware), impl_(impl) {} + + template <typename UserType, typename Iterator, typename ProtoType> + ::grpc::Status WriteStream(Iterator* iterator, ServerWriter<ProtoType>* writer) { + if (!iterator) { + return ::grpc::Status(::grpc::StatusCode::INTERNAL, "No items to iterate"); + } + // Write flight info to stream until listing is exhausted + while (true) { + ProtoType pb_value; + std::unique_ptr<UserType> value; + GRPC_RETURN_NOT_OK(iterator->Next(&value)); + if (!value) { + break; + } + GRPC_RETURN_NOT_OK(internal::ToProto(*value, &pb_value)); + + // Blocking write + if (!writer->Write(pb_value)) { + // Write returns false if the stream is closed + break; + } + } + return ::grpc::Status::OK; + } + + template <typename UserType, typename ProtoType> + ::grpc::Status WriteStream(const std::vector<UserType>& values, + ServerWriter<ProtoType>* writer) { + // Write flight info to stream until listing is exhausted + for (const UserType& value : values) { + ProtoType pb_value; + GRPC_RETURN_NOT_OK(internal::ToProto(value, &pb_value)); + // Blocking write + if (!writer->Write(pb_value)) { + // Write returns false if the stream is closed + break; + } + } + return ::grpc::Status::OK; + } + + // Authenticate the client (if applicable) and construct the call context + ::grpc::Status CheckAuth(const FlightMethod& method, ServerContext* context, + GrpcServerCallContext& flight_context) { + if (!auth_handler_) { + const auto auth_context = context->auth_context(); + if (auth_context && auth_context->IsPeerAuthenticated()) { + auto peer_identity = auth_context->GetPeerIdentity(); + flight_context.peer_identity_ = + peer_identity.empty() + ? "" + : std::string(peer_identity.front().begin(), peer_identity.front().end()); + } else { + flight_context.peer_identity_ = ""; + } + } else { + const auto client_metadata = context->client_metadata(); + const auto auth_header = client_metadata.find(internal::kGrpcAuthHeader); + std::string token; + if (auth_header == client_metadata.end()) { + token = ""; + } else { + token = std::string(auth_header->second.data(), auth_header->second.length()); + } + GRPC_RETURN_NOT_OK(auth_handler_->IsValid(token, &flight_context.peer_identity_)); + } + + return MakeCallContext(method, context, flight_context); + } + + // Authenticate the client (if applicable) and construct the call context + ::grpc::Status MakeCallContext(const FlightMethod& method, ServerContext* context, + GrpcServerCallContext& flight_context) { Review comment: This part as well seems morally transport-agnostic? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org