This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new c1507ad20a generic channel support for FlightClient (#9933)
c1507ad20a is described below
commit c1507ad20a3dad44353bd9fb4c489785298d10d8
Author: Rostislav Rumenov <[email protected]>
AuthorDate: Thu May 7 20:10:50 2026 +0200
generic channel support for FlightClient (#9933)
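Allow `FlightClient` to be parameterized over the underlying channel type
(defaulting to `tonic::transport::Channel`), so users can wrap a tonic channel
with custom interceptors or services, e.g.
`FlightClient::new(InterceptedService::new(channel, interceptor))`.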
Motivation: Annotating outbound Flight requests with metadata (e.g. injecting
OpenTelemetry trace context into headers) currently requires forking or
wrapping at a higher level. Making the channel generic lets callers compose
tower layers/interceptors idiomatically and propagate distributed tracing
context without bespoke plumbing.
---------
Co-authored-by: Rostislav Rumenov <[email protected]>
---
arrow-flight/src/client.rs | 291 +++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 280 insertions(+), 11 deletions(-)
diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index dac086271c..b2059a81d0 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -31,6 +31,7 @@ use futures::{
stream::{self, BoxStream},
};
use prost::Message;
+use tonic::codegen::{Body, StdError};
use tonic::{metadata::MetadataMap, transport::Channel};
use crate::error::{FlightError, Result};
@@ -67,22 +68,28 @@ use crate::streams::{FallibleRequestStream,
FallibleTonicResponseStream};
/// # }
/// ```
#[derive(Debug)]
-pub struct FlightClient {
+pub struct FlightClient<T = Channel> {
/// Optional grpc header metadata to include with each request
metadata: MetadataMap,
/// The inner client
- inner: FlightServiceClient<Channel>,
+ inner: FlightServiceClient<T>,
}
-impl FlightClient {
- /// Creates a client client with the provided [`Channel`]
- pub fn new(channel: Channel) -> Self {
- Self::new_from_inner(FlightServiceClient::new(channel))
+impl<T> FlightClient<T>
+where
+ T: tonic::client::GrpcService<tonic::body::Body>,
+ T::Error: Into<StdError>,
+ T::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
+ <T::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
+{
+ /// Creates a client with the provided transport
+ pub fn new(inner: T) -> Self {
+ Self::new_from_inner(FlightServiceClient::new(inner))
}
/// Creates a new higher level client with the provided lower level client
- pub fn new_from_inner(inner: FlightServiceClient<Channel>) -> Self {
+ pub fn new_from_inner(inner: FlightServiceClient<T>) -> Self {
Self {
metadata: MetadataMap::new(),
inner,
@@ -120,19 +127,19 @@ impl FlightClient {
/// Return a reference to the underlying tonic
/// [`FlightServiceClient`]
- pub fn inner(&self) -> &FlightServiceClient<Channel> {
+ pub fn inner(&self) -> &FlightServiceClient<T> {
&self.inner
}
/// Return a mutable reference to the underlying tonic
/// [`FlightServiceClient`]
- pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
+ pub fn inner_mut(&mut self) -> &mut FlightServiceClient<T> {
&mut self.inner
}
/// Consume this client and return the underlying tonic
/// [`FlightServiceClient`]
- pub fn into_inner(self) -> FlightServiceClient<Channel> {
+ pub fn into_inner(self) -> FlightServiceClient<T> {
self.inner
}
@@ -664,10 +671,272 @@ impl FlightClient {
}
/// return a Request, adding any configured metadata
- fn make_request<T>(&self, t: T) -> tonic::Request<T> {
+ fn make_request<R>(&self, t: R) -> tonic::Request<R> {
// Pass along metadata
let mut request = tonic::Request::new(t);
*request.metadata_mut() = self.metadata.clone();
request
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::FlightClient;
+ use crate::encode::FlightDataEncoderBuilder;
+ use crate::flight_service_server::{FlightService, FlightServiceServer};
+ use crate::{
+ Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor,
FlightInfo,
+ HandshakeRequest, HandshakeResponse, PollInfo, PutResult,
SchemaResult, Ticket,
+ };
+ use arrow_array::{RecordBatch, UInt64Array};
+ use bytes::Bytes;
+ use futures::{StreamExt, TryStreamExt, stream::BoxStream};
+ use std::net::SocketAddr;
+ use std::sync::{Arc, Mutex};
+ use std::time::Duration;
+ use tokio::net::TcpListener;
+ use tokio::task::JoinHandle;
+ use tonic::metadata::MetadataMap;
+ use tonic::service::interceptor::InterceptedService;
+ use tonic::transport::Channel;
+ use tonic::{Request, Response, Status, Streaming};
+ use uuid::Uuid;
+
+ /// Minimal `FlightService` that records request metadata and serves a
+ /// configured `do_get` response. Other RPCs return `Unimplemented`.
+ #[derive(Debug, Clone, Default)]
+ struct InterceptorTestServer {
+ state: Arc<Mutex<InterceptorTestState>>,
+ }
+
+ #[derive(Debug, Default)]
+ struct InterceptorTestState {
+ do_get_request: Option<Ticket>,
+ do_get_response: Option<Vec<Result<RecordBatch, Status>>>,
+ last_request_metadata: Option<MetadataMap>,
+ }
+
+ impl InterceptorTestServer {
+ fn save_metadata<T>(&self, request: &Request<T>) {
+ self.state.lock().unwrap().last_request_metadata =
Some(request.metadata().clone());
+ }
+
+ fn set_do_get_response(&self, response: Vec<Result<RecordBatch,
Status>>) {
+ self.state.lock().unwrap().do_get_response = Some(response);
+ }
+
+ fn take_do_get_request(&self) -> Option<Ticket> {
+ self.state.lock().unwrap().do_get_request.take()
+ }
+
+ fn take_last_request_metadata(&self) -> Option<MetadataMap> {
+ self.state.lock().unwrap().last_request_metadata.take()
+ }
+ }
+
+ #[tonic::async_trait]
+ impl FlightService for InterceptorTestServer {
+ type HandshakeStream = BoxStream<'static, Result<HandshakeResponse,
Status>>;
+ type ListFlightsStream = BoxStream<'static, Result<FlightInfo,
Status>>;
+ type DoGetStream = BoxStream<'static, Result<FlightData, Status>>;
+ type DoPutStream = BoxStream<'static, Result<PutResult, Status>>;
+ type DoActionStream = BoxStream<'static, Result<crate::Result,
Status>>;
+ type ListActionsStream = BoxStream<'static, Result<ActionType,
Status>>;
+ type DoExchangeStream = BoxStream<'static, Result<FlightData, Status>>;
+
+ async fn do_get(
+ &self,
+ request: Request<Ticket>,
+ ) -> Result<Response<Self::DoGetStream>, Status> {
+ self.save_metadata(&request);
+ let mut state = self.state.lock().unwrap();
+ state.do_get_request = Some(request.into_inner());
+
+ let batches = state
+ .do_get_response
+ .take()
+ .ok_or_else(|| Status::internal("no do_get response
configured"))?;
+ let batch_stream =
futures::stream::iter(batches).map_err(Into::into);
+ let stream = FlightDataEncoderBuilder::new()
+ .build(batch_stream)
+ .map_err(Into::into);
+ Ok(Response::new(stream.boxed()))
+ }
+
+ async fn handshake(
+ &self,
+ _: Request<Streaming<HandshakeRequest>>,
+ ) -> Result<Response<Self::HandshakeStream>, Status> {
+ Err(Status::unimplemented(""))
+ }
+ async fn list_flights(
+ &self,
+ _: Request<Criteria>,
+ ) -> Result<Response<Self::ListFlightsStream>, Status> {
+ Err(Status::unimplemented(""))
+ }
+ async fn get_flight_info(
+ &self,
+ _: Request<FlightDescriptor>,
+ ) -> Result<Response<FlightInfo>, Status> {
+ Err(Status::unimplemented(""))
+ }
+ async fn poll_flight_info(
+ &self,
+ _: Request<FlightDescriptor>,
+ ) -> Result<Response<PollInfo>, Status> {
+ Err(Status::unimplemented(""))
+ }
+ async fn get_schema(
+ &self,
+ _: Request<FlightDescriptor>,
+ ) -> Result<Response<SchemaResult>, Status> {
+ Err(Status::unimplemented(""))
+ }
+ async fn do_put(
+ &self,
+ _: Request<Streaming<FlightData>>,
+ ) -> Result<Response<Self::DoPutStream>, Status> {
+ Err(Status::unimplemented(""))
+ }
+ async fn do_action(
+ &self,
+ _: Request<Action>,
+ ) -> Result<Response<Self::DoActionStream>, Status> {
+ Err(Status::unimplemented(""))
+ }
+ async fn list_actions(
+ &self,
+ _: Request<Empty>,
+ ) -> Result<Response<Self::ListActionsStream>, Status> {
+ Err(Status::unimplemented(""))
+ }
+ async fn do_exchange(
+ &self,
+ _: Request<Streaming<FlightData>>,
+ ) -> Result<Response<Self::DoExchangeStream>, Status> {
+ Err(Status::unimplemented(""))
+ }
+ }
+
+ /// Spawns the test server on a background task and exposes a connected
channel.
+ struct InterceptorTestFixture {
+ shutdown: Option<tokio::sync::oneshot::Sender<()>>,
+ addr: SocketAddr,
+ handle: Option<JoinHandle<Result<(), tonic::transport::Error>>>,
+ }
+
+ impl InterceptorTestFixture {
+ async fn new(server: FlightServiceServer<InterceptorTestServer>) ->
Self {
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+ let addr = listener.local_addr().unwrap();
+ let (tx, rx) = tokio::sync::oneshot::channel();
+ let shutdown_future = async move {
+ rx.await.ok();
+ };
+ let serve = tonic::transport::Server::builder()
+ .timeout(Duration::from_secs(30))
+ .add_service(server)
+ .serve_with_incoming_shutdown(
+ tokio_stream::wrappers::TcpListenerStream::new(listener),
+ shutdown_future,
+ );
+ let handle = tokio::task::spawn(serve);
+ Self {
+ shutdown: Some(tx),
+ addr,
+ handle: Some(handle),
+ }
+ }
+
+ async fn channel(&self) -> Channel {
+ let url = format!("http://{}", self.addr);
+ tonic::transport::Endpoint::from_shared(url)
+ .expect("valid endpoint")
+ .timeout(Duration::from_secs(30))
+ .connect()
+ .await
+ .expect("error connecting to server")
+ }
+
+ async fn shutdown_and_wait(mut self) {
+ if let Some(tx) = self.shutdown.take() {
+ tx.send(()).expect("server quit early");
+ }
+ if let Some(handle) = self.handle.take() {
+ handle
+ .await
+ .expect("task join error (panic?)")
+ .expect("server error at shutdown");
+ }
+ }
+ }
+
+ /// Integration test: a tonic [`Channel`] wrapped in an
[`InterceptedService`]
+ /// that injects a custom header is passed to [`FlightClient`], and the
server
+ /// observes the header on the request.
+ #[tokio::test]
+ async fn
test_flight_client_with_intercepted_channel_passes_custom_header() {
+ let test_server = InterceptorTestServer::default();
+ let fixture =
+
InterceptorTestFixture::new(FlightServiceServer::new(test_server.clone())).await;
+
+ let channel = fixture.channel().await;
+
+ let header_name = "x-random-header";
+ let header_value = format!("random-{}", Uuid::new_v4());
+ let header_value_for_interceptor = header_value.clone();
+
+ let interceptor = move |mut req: Request<()>| -> Result<Request<()>,
Status> {
+ req.metadata_mut().insert(
+ header_name,
+ header_value_for_interceptor
+ .parse()
+ .expect("valid metadata value"),
+ );
+ Ok(req)
+ };
+
+ let intercepted = InterceptedService::new(channel, interceptor);
+ let mut client = FlightClient::new(intercepted);
+
+ let ticket = Ticket {
+ ticket: Bytes::from("dummy-ticket"),
+ };
+
+ let batch = RecordBatch::try_from_iter(vec![(
+ "col",
+ Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
+ )])
+ .unwrap();
+
+ test_server.set_do_get_response(vec![Ok(batch.clone())]);
+
+ let response_stream = client
+ .do_get(ticket.clone())
+ .await
+ .expect("error making do_get request");
+
+ let response: Vec<RecordBatch> = response_stream
+ .try_collect()
+ .await
+ .expect("error streaming data");
+
+ assert_eq!(response, vec![batch]);
+ assert_eq!(test_server.take_do_get_request(), Some(ticket));
+
+ let metadata = test_server
+ .take_last_request_metadata()
+ .expect("server received headers")
+ .into_headers();
+
+ let received = metadata
+ .get(header_name)
+ .expect("interceptor header missing on server")
+ .to_str()
+ .expect("ascii header value");
+ assert_eq!(received, header_value);
+
+ fixture.shutdown_and_wait().await;
+ }
+}