This is an automated email from the ASF dual-hosted git repository.

mneumann pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 4533271b4b feat: expose DoGet response headers & trailers (#4727)
4533271b4b is described below

commit 4533271b4b221a5e28fa3215cb3cbddaafafdd84
Author: Marco Neumann <ma...@crepererum.net>
AuthorDate: Fri Aug 25 12:45:20 2023 +0200

    feat: expose DoGet response headers & trailers (#4727)
    
    * feat: expose DoGet response headers & trailers
    
    * docs: improve
    
    Co-authored-by: Andrew Lamb <and...@nerdnetworks.org>
    
    * refactor: address review comments
    
    ---------
    
    Co-authored-by: Andrew Lamb <and...@nerdnetworks.org>
---
 arrow-flight/Cargo.toml                     |   3 +
 arrow-flight/src/client.rs                  |  20 ++--
 arrow-flight/src/decode.rs                  |  44 ++++++++-
 arrow-flight/src/lib.rs                     |   3 +
 arrow-flight/src/trailers.rs                |  97 +++++++++++++++++++
 arrow-flight/tests/client.rs                |  34 ++++++-
 arrow-flight/tests/common/server.rs         |   6 +-
 arrow-flight/tests/common/trailers_layer.rs | 138 ++++++++++++++++++++++++++++
 8 files changed, 327 insertions(+), 18 deletions(-)

diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml
index 3ed426a21f..1a53dbddb1 100644
--- a/arrow-flight/Cargo.toml
+++ b/arrow-flight/Cargo.toml
@@ -67,6 +67,9 @@ cli = ["arrow-cast/prettyprint", "clap", "tracing-log", 
"tracing-subscriber", "t
 [dev-dependencies]
 arrow-cast = { workspace = true, features = ["prettyprint"] }
 assert_cmd = "2.0.8"
+http = "0.2.9"
+http-body = "0.4.5"
+pin-project-lite = "0.2"
 tempfile = "3.3"
 tokio-stream = { version = "0.1", features = ["net"] }
 tower = "0.4.13"
diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index 2c952fb3bf..8793f7834b 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -18,9 +18,9 @@
 use std::task::Poll;
 
 use crate::{
-    decode::FlightRecordBatchStream, 
flight_service_client::FlightServiceClient, Action,
-    ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
-    HandshakeRequest, PutResult, Ticket,
+    decode::FlightRecordBatchStream, 
flight_service_client::FlightServiceClient,
+    trailers::extract_lazy_trailers, Action, ActionType, Criteria, Empty, 
FlightData,
+    FlightDescriptor, FlightInfo, HandshakeRequest, PutResult, Ticket,
 };
 use arrow_schema::Schema;
 use bytes::Bytes;
@@ -204,16 +204,14 @@ impl FlightClient {
     pub async fn do_get(&mut self, ticket: Ticket) -> 
Result<FlightRecordBatchStream> {
         let request = self.make_request(ticket);
 
-        let response_stream = self
-            .inner
-            .do_get(request)
-            .await?
-            .into_inner()
-            .map_err(FlightError::Tonic);
+        let (md, response_stream, _ext) = 
self.inner.do_get(request).await?.into_parts();
+        let (response_stream, trailers) = 
extract_lazy_trailers(response_stream);
 
         Ok(FlightRecordBatchStream::new_from_flight_data(
-            response_stream,
-        ))
+            response_stream.map_err(FlightError::Tonic),
+        )
+        .with_headers(md)
+        .with_trailers(trailers))
     }
 
     /// Make a `GetFlightInfo` call to the server with the provided
diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs
index df74923332..dfcdd26060 100644
--- a/arrow-flight/src/decode.rs
+++ b/arrow-flight/src/decode.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::{utils::flight_data_to_arrow_batch, FlightData};
+use crate::{trailers::LazyTrailers, utils::flight_data_to_arrow_batch, 
FlightData};
 use arrow_array::{ArrayRef, RecordBatch};
 use arrow_buffer::Buffer;
 use arrow_schema::{Schema, SchemaRef};
@@ -24,6 +24,7 @@ use futures::{ready, stream::BoxStream, Stream, StreamExt};
 use std::{
     collections::HashMap, convert::TryFrom, fmt::Debug, pin::Pin, sync::Arc, 
task::Poll,
 };
+use tonic::metadata::MetadataMap;
 
 use crate::error::{FlightError, Result};
 
@@ -82,13 +83,23 @@ use crate::error::{FlightError, Result};
 /// ```
 #[derive(Debug)]
 pub struct FlightRecordBatchStream {
+    /// gRPC response header metadata (the default map unless set via `with_headers`).
+    headers: MetadataMap,
+
+    /// Optional grpc trailer metadata.
+    trailers: Option<LazyTrailers>,
+
     inner: FlightDataDecoder,
 }
 
 impl FlightRecordBatchStream {
     /// Create a new [`FlightRecordBatchStream`] from a decoded stream
     pub fn new(inner: FlightDataDecoder) -> Self {
-        Self { inner }
+        Self {
+            inner,
+            headers: MetadataMap::default(),
+            trailers: None,
+        }
     }
 
     /// Create a new [`FlightRecordBatchStream`] from a stream of 
[`FlightData`]
@@ -98,9 +109,37 @@ impl FlightRecordBatchStream {
     {
         Self {
             inner: FlightDataDecoder::new(inner),
+            headers: MetadataMap::default(),
+            trailers: None,
+        }
+    }
+
+    /// Record response headers.
+    pub fn with_headers(self, headers: MetadataMap) -> Self {
+        Self { headers, ..self }
+    }
+
+    /// Record response trailers.
+    pub fn with_trailers(self, trailers: LazyTrailers) -> Self {
+        Self {
+            trailers: Some(trailers),
+            ..self
         }
     }
 
+    /// Headers attached to this stream.
+    pub fn headers(&self) -> &MetadataMap {
+        &self.headers
+    }
+
+    /// Trailers attached to this stream.
+    ///
+    /// Note that this will return `None` until the entire stream is consumed.
+    /// Only once `next()` has returned `None` can any available 
trailers be returned.
+    pub fn trailers(&self) -> Option<MetadataMap> {
+        self.trailers.as_ref().and_then(|trailers| trailers.get())
+    }
+
     /// Has a message defining the schema been received yet?
     #[deprecated = "use schema().is_some() instead"]
     pub fn got_schema(&self) -> bool {
@@ -117,6 +156,7 @@ impl FlightRecordBatchStream {
         self.inner
     }
 }
+
 impl futures::Stream for FlightRecordBatchStream {
     type Item = Result<RecordBatch>;
 
diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs
index 4163f2ceaa..04edf26638 100644
--- a/arrow-flight/src/lib.rs
+++ b/arrow-flight/src/lib.rs
@@ -111,6 +111,9 @@ pub use gen::Result;
 pub use gen::SchemaResult;
 pub use gen::Ticket;
 
+/// Helper to extract HTTP/gRPC trailers from a tonic stream.
+mod trailers;
+
 pub mod utils;
 
 #[cfg(feature = "flight-sql-experimental")]
diff --git a/arrow-flight/src/trailers.rs b/arrow-flight/src/trailers.rs
new file mode 100644
index 0000000000..d652542da7
--- /dev/null
+++ b/arrow-flight/src/trailers.rs
@@ -0,0 +1,97 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{
+    pin::Pin,
+    sync::{Arc, Mutex},
+    task::{Context, Poll},
+};
+
+use futures::{ready, FutureExt, Stream, StreamExt};
+use tonic::{metadata::MetadataMap, Status, Streaming};
+
+/// Extract [`LazyTrailers`] from [`Streaming`] [tonic] response.
+///
+/// Note that [`LazyTrailers`] has interior mutability and will only hold actual 
data after [`ExtractTrailersStream`] is
+/// fully consumed (dropping it is not required though).
+pub fn extract_lazy_trailers<T>(
+    s: Streaming<T>,
+) -> (ExtractTrailersStream<T>, LazyTrailers) {
+    let trailers: SharedTrailers = Default::default();
+    let stream = ExtractTrailersStream {
+        inner: s,
+        trailers: Arc::clone(&trailers),
+    };
+    let lazy_trailers = LazyTrailers { trailers };
+    (stream, lazy_trailers)
+}
+
+type SharedTrailers = Arc<Mutex<Option<MetadataMap>>>;
+
+/// [Stream] that stores the gRPC trailers into [`LazyTrailers`].
+///
+/// See [`extract_lazy_trailers`] for construction.
+#[derive(Debug)]
+pub struct ExtractTrailersStream<T> {
+    inner: Streaming<T>,
+    trailers: SharedTrailers,
+}
+
+impl<T> Stream for ExtractTrailersStream<T> {
+    type Item = Result<T, Status>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let res = ready!(self.inner.poll_next_unpin(cx));
+
+        if res.is_none() {
+            // stream exhausted => trailers should be available
+            if let Some(trailers) = self
+                .inner
+                .trailers()
+                .now_or_never()
+                .and_then(|res| res.ok())
+                .flatten()
+            {
+                *self.trailers.lock().expect("poisoned") = Some(trailers);
+            }
+        }
+
+        Poll::Ready(res)
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.inner.size_hint()
+    }
+}
+
+/// gRPC trailers that are extracted by [`ExtractTrailersStream`].
+///
+/// See [`extract_lazy_trailers`] for construction.
+#[derive(Debug)]
+pub struct LazyTrailers {
+    trailers: SharedTrailers,
+}
+
+impl LazyTrailers {
+    /// gRPC trailers that are known at the end of a stream.
+    pub fn get(&self) -> Option<MetadataMap> {
+        self.trailers.lock().expect("poisoned").clone()
+    }
+}
diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs
index 8ea542879a..1b9891e121 100644
--- a/arrow-flight/tests/client.rs
+++ b/arrow-flight/tests/client.rs
@@ -19,6 +19,7 @@
 
 mod common {
     pub mod server;
+    pub mod trailers_layer;
 }
 use arrow_array::{RecordBatch, UInt64Array};
 use arrow_flight::{
@@ -28,7 +29,7 @@ use arrow_flight::{
 };
 use arrow_schema::{DataType, Field, Schema};
 use bytes::Bytes;
-use common::server::TestFlightServer;
+use common::{server::TestFlightServer, trailers_layer::TrailersLayer};
 use futures::{Future, StreamExt, TryStreamExt};
 use tokio::{net::TcpListener, task::JoinHandle};
 use tonic::{
@@ -158,18 +159,42 @@ async fn test_do_get() {
 
         let response = vec![Ok(batch.clone())];
         test_server.set_do_get_response(response);
-        let response_stream = client
+        let mut response_stream = client
             .do_get(ticket.clone())
             .await
             .expect("error making request");
 
+        assert_eq!(
+            response_stream
+                .headers()
+                .get("test-resp-header")
+                .expect("header exists")
+                .to_str()
+                .unwrap(),
+            "some_val",
+        );
+
+        // trailers are not available before stream exhaustion
+        assert!(response_stream.trailers().is_none());
+
         let expected_response = vec![batch];
-        let response: Vec<_> = response_stream
+        let response: Vec<_> = (&mut response_stream)
             .try_collect()
             .await
             .expect("Error streaming data");
-
         assert_eq!(response, expected_response);
+
+        assert_eq!(
+            response_stream
+                .trailers()
+                .expect("stream exhausted")
+                .get("test-trailer")
+                .expect("trailer exists")
+                .to_str()
+                .unwrap(),
+            "trailer_val",
+        );
+
         assert_eq!(test_server.take_do_get_request(), Some(ticket));
         ensure_metadata(&client, &test_server);
     })
@@ -932,6 +957,7 @@ impl TestFixture {
 
         let serve_future = tonic::transport::Server::builder()
             .timeout(server_timeout)
+            .layer(TrailersLayer)
             .add_service(test_server.service())
             .serve_with_incoming_shutdown(
                 tokio_stream::wrappers::TcpListenerStream::new(listener),
diff --git a/arrow-flight/tests/common/server.rs 
b/arrow-flight/tests/common/server.rs
index b87019d632..c575d12bbf 100644
--- a/arrow-flight/tests/common/server.rs
+++ b/arrow-flight/tests/common/server.rs
@@ -359,7 +359,11 @@ impl FlightService for TestFlightServer {
             .build(batch_stream)
             .map_err(Into::into);
 
-        Ok(Response::new(stream.boxed()))
+        let mut resp = Response::new(stream.boxed());
+        resp.metadata_mut()
+            .insert("test-resp-header", "some_val".parse().unwrap());
+
+        Ok(resp)
     }
 
     async fn do_put(
diff --git a/arrow-flight/tests/common/trailers_layer.rs 
b/arrow-flight/tests/common/trailers_layer.rs
new file mode 100644
index 0000000000..9e6be0dcf0
--- /dev/null
+++ b/arrow-flight/tests/common/trailers_layer.rs
@@ -0,0 +1,138 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::future::Future;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+use futures::ready;
+use http::{HeaderValue, Request, Response};
+use http_body::SizeHint;
+use pin_project_lite::pin_project;
+use tower::{Layer, Service};
+
+#[derive(Debug, Copy, Clone, Default)]
+pub struct TrailersLayer;
+
+impl<S> Layer<S> for TrailersLayer {
+    type Service = TrailersService<S>;
+
+    fn layer(&self, service: S) -> Self::Service {
+        TrailersService { service }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TrailersService<S> {
+    service: S,
+}
+
+impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for TrailersService<S>
+where
+    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
+    ResBody: http_body::Body,
+{
+    type Response = Response<WrappedBody<ResBody>>;
+    type Error = S::Error;
+    type Future = WrappedFuture<S::Future>;
+
+    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), 
Self::Error>> {
+        self.service.poll_ready(cx)
+    }
+
+    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
+        WrappedFuture {
+            inner: self.service.call(request),
+        }
+    }
+}
+
+pin_project! {
+    #[derive(Debug)]
+    pub struct WrappedFuture<F> {
+        #[pin]
+        inner: F,
+    }
+}
+
+impl<F, ResBody, Error> Future for WrappedFuture<F>
+where
+    F: Future<Output = Result<Response<ResBody>, Error>>,
+    ResBody: http_body::Body,
+{
+    type Output = Result<Response<WrappedBody<ResBody>>, Error>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> 
Poll<Self::Output> {
+        let result: Result<Response<ResBody>, Error> =
+            ready!(self.as_mut().project().inner.poll(cx));
+
+        match result {
+            Ok(response) => {
+                Poll::Ready(Ok(response.map(|body| WrappedBody { inner: body 
})))
+            }
+            Err(e) => Poll::Ready(Err(e)),
+        }
+    }
+}
+
+pin_project! {
+    #[derive(Debug)]
+    pub struct WrappedBody<B> {
+        #[pin]
+        inner: B,
+    }
+}
+
+impl<B: http_body::Body> http_body::Body for WrappedBody<B> {
+    type Data = B::Data;
+    type Error = B::Error;
+
+    fn poll_data(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
+        self.as_mut().project().inner.poll_data(cx)
+    }
+
+    fn poll_trailers(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<Option<http::header::HeaderMap>, Self::Error>> {
+        let result: Result<Option<http::header::HeaderMap>, Self::Error> =
+            ready!(self.as_mut().project().inner.poll_trailers(cx));
+
+        let mut trailers = http::header::HeaderMap::new();
+        trailers.insert("test-trailer", 
HeaderValue::from_static("trailer_val"));
+
+        match result {
+            Ok(Some(mut existing)) => {
+                existing.extend(trailers.iter().map(|(k, v)| (k.clone(), 
v.clone())));
+                Poll::Ready(Ok(Some(existing)))
+            }
+            Ok(None) => Poll::Ready(Ok(Some(trailers))),
+            Err(e) => Poll::Ready(Err(e)),
+        }
+    }
+
+    fn is_end_stream(&self) -> bool {
+        self.inner.is_end_stream()
+    }
+
+    fn size_hint(&self) -> SizeHint {
+        self.inner.size_hint()
+    }
+}

Reply via email to