lidavidm commented on a change in pull request #12465: URL: https://github.com/apache/arrow/pull/12465#discussion_r818848309
########## File path: cpp/src/arrow/flight/transport/grpc/grpc_server.cc ########## @@ -0,0 +1,634 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// gRPC transport implementation for Arrow Flight + +#include "arrow/flight/transport/grpc/grpc_server.h" + +#include <mutex> +#include <sstream> +#include <string> +#include <unordered_map> +#include <utility> + +#include "arrow/util/config.h" +#ifdef GRPCPP_PP_INCLUDE +#include <grpcpp/grpcpp.h> +#else +#include <grpc++/grpc++.h> +#endif + +#include "arrow/buffer.h" +#include "arrow/flight/internal.h" +#include "arrow/flight/serialization_internal.h" +#include "arrow/flight/server.h" +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/transport.h" +#include "arrow/flight/transport_server.h" +#include "arrow/flight/types.h" +#include "arrow/util/logging.h" +#include "arrow/util/uri.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace grpc { + +namespace pb = arrow::flight::protocol; +using FlightService = pb::FlightService; +using ServerContext = ::grpc::ServerContext; +template <typename T> +using ServerWriter = ::grpc::ServerWriter<T>; + +// Macro that runs interceptors before returning the given status +#define 
RETURN_WITH_MIDDLEWARE(CONTEXT, STATUS) \ + do { \ + const auto& __s = (STATUS); \ + return CONTEXT.FinishRequest(__s); \ + } while (false) +#define CHECK_ARG_NOT_NULL(CONTEXT, VAL, MESSAGE) \ + if (VAL == nullptr) { \ + RETURN_WITH_MIDDLEWARE( \ + CONTEXT, ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT, MESSAGE)); \ + } +// Same as RETURN_NOT_OK, but accepts either Arrow or gRPC status, and +// will run interceptors +#define SERVICE_RETURN_NOT_OK(CONTEXT, expr) \ + do { \ + const auto& _s = (expr); \ + if (ARROW_PREDICT_FALSE(!_s.ok())) { \ + return CONTEXT.FinishRequest(_s); \ + } \ + } while (false) + +namespace { +class GrpcServerAuthReader : public ServerAuthReader { + public: + explicit GrpcServerAuthReader( + ::grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream) + : stream_(stream) {} + + Status Read(std::string* token) override { + pb::HandshakeRequest request; + if (stream_->Read(&request)) { + *token = std::move(*request.mutable_payload()); + return Status::OK(); + } + return Status::IOError("Stream is closed."); + } + + private: + ::grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream_; +}; + +class GrpcServerAuthSender : public ServerAuthSender { + public: + explicit GrpcServerAuthSender( + ::grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream) + : stream_(stream) {} + + Status Write(const std::string& token) override { + pb::HandshakeResponse response; + response.set_payload(token); + if (stream_->Write(response)) { + return Status::OK(); + } + return Status::IOError("Stream was closed."); + } + + private: + ::grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream_; +}; + +class GrpcServerCallContext : public ServerCallContext { + explicit GrpcServerCallContext(::grpc::ServerContext* context) + : context_(context), peer_(context_->peer()) {} + + const std::string& peer_identity() const override { return peer_identity_; } + const std::string& peer() 
const override { return peer_; } + bool is_cancelled() const override { return context_->IsCancelled(); } + + // Helper method that runs interceptors given the result of an RPC, + // then returns the final gRPC status to send to the client + ::grpc::Status FinishRequest(const ::grpc::Status& status) { + // Don't double-convert status - return the original one here + FinishRequest(internal::FromGrpcStatus(status)); + return status; + } + + ::grpc::Status FinishRequest(const arrow::Status& status) { + for (const auto& instance : middleware_) { + instance->CallCompleted(status); + } + + // Set custom headers to map the exact Arrow status for clients + // who want it. + return internal::ToGrpcStatus(status, context_); + } + + ServerMiddleware* GetMiddleware(const std::string& key) const override { Review comment: Yes, the same goes for auth. I just haven't focused on that so far (the focus was on just getting UCX up and running, and the core changes were then split out here). I'll file issues to follow up on these shortly. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
