This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 42ffc3f34 Complete mid-level `FlightClient` (#3402)
42ffc3f34 is described below
commit 42ffc3f344b338289d5e8e6b12b247f791dd5d8f
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Jan 5 09:39:28 2023 -0500
Complete mid-level `FlightClient` (#3402)
* Implement `FlightClient::do_put` and `FlightClient::do_exchange`
* Implement FlightClient::{list_flights, list_actions, do_action, get_schema}
* Apply suggestions from code review
Co-authored-by: Liang-Chi Hsieh <[email protected]>
Co-authored-by: Raphael Taylor-Davies <[email protected]>
* remove outdated comment
* make foo/bar placeholders in test more specific
* simplify tests
Co-authored-by: Liang-Chi Hsieh <[email protected]>
Co-authored-by: Raphael Taylor-Davies <[email protected]>
---
arrow-flight/src/client.rs | 303 +++++++++++++++++++++++++++++++++---
arrow-flight/src/lib.rs | 26 +++-
arrow-flight/tests/client.rs | 297 ++++++++++++++++++++++++++++++-----
arrow-flight/tests/common/server.rs | 85 +++++++++-
4 files changed, 638 insertions(+), 73 deletions(-)
diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index 753c40f2a..bdd51dda4 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -16,11 +16,17 @@
// under the License.
use crate::{
- decode::FlightRecordBatchStream,
flight_service_client::FlightServiceClient,
- FlightDescriptor, FlightInfo, HandshakeRequest, Ticket,
+ decode::FlightRecordBatchStream,
flight_service_client::FlightServiceClient, Action,
+ ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
+ HandshakeRequest, PutResult, Ticket,
};
+use arrow_schema::Schema;
use bytes::Bytes;
-use futures::{future::ready, stream, StreamExt, TryStreamExt};
+use futures::{
+ future::ready,
+ stream::{self, BoxStream},
+ Stream, StreamExt, TryStreamExt,
+};
use tonic::{metadata::MetadataMap, transport::Channel};
use crate::error::{FlightError, Result};
@@ -160,6 +166,11 @@ impl FlightClient {
/// returning a [`FlightRecordBatchStream`] for reading
/// [`RecordBatch`](arrow_array::RecordBatch)es.
///
+ /// # Note
+ ///
+ /// To access the returned [`FlightData`] use
+ /// [`FlightRecordBatchStream::into_inner()`]
+ ///
/// # Example:
/// ```no_run
/// # async fn run() {
@@ -167,12 +178,8 @@ impl FlightClient {
/// # use arrow_flight::FlightClient;
/// # use arrow_flight::Ticket;
/// # use arrow_array::RecordBatch;
- /// # use tonic::transport::Channel;
/// # use futures::stream::TryStreamExt;
- /// # let channel = Channel::from_static("http://localhost:1234")
- /// # .connect()
- /// # .await
- /// # .expect("error connecting");
+ /// # let channel: tonic::transport::Channel = unimplemented!();
/// # let ticket = Ticket { ticket: Bytes::from("foo") };
/// let mut client = FlightClient::new(channel);
///
@@ -199,8 +206,7 @@ impl FlightClient {
.do_get(request)
.await?
.into_inner()
- // convert to FlightError
- .map_err(|e| e.into());
+ .map_err(FlightError::Tonic);
Ok(FlightRecordBatchStream::new_from_flight_data(
response_stream,
@@ -217,11 +223,7 @@ impl FlightClient {
/// # async fn run() {
/// # use arrow_flight::FlightClient;
/// # use arrow_flight::FlightDescriptor;
- /// # use tonic::transport::Channel;
- /// # let channel = Channel::from_static("http://localhost:1234")
- /// # .connect()
- /// # .await
- /// # .expect("error connecting");
+ /// # let channel: tonic::transport::Channel = unimplemented!();
/// let mut client = FlightClient::new(channel);
///
/// // Send a 'CMD' request to the server
@@ -256,13 +258,270 @@ impl FlightClient {
Ok(response)
}
- // TODO other methods
- // list_flights
- // get_schema
- // do_put
- // do_action
- // list_actions
- // do_exchange
+ /// Make a `DoPut` call to the server with the provided
+ /// [`Stream`](futures::Stream) of [`FlightData`] and return a
+ /// stream of [`PutResult`]; errors within that stream are reported as [`FlightError::Tonic`].
+ ///
+ /// # Example:
+ /// ```no_run
+ /// # async fn run() {
+ /// # use futures::{TryStreamExt, StreamExt};
+ /// # use std::sync::Arc;
+ /// # use arrow_array::UInt64Array;
+ /// # use arrow_array::RecordBatch;
+ /// # use arrow_flight::{FlightClient, PutResult};
+ /// # use arrow_flight::encode::FlightDataEncoderBuilder;
+ /// # let batch = RecordBatch::try_from_iter(vec![
+ /// # ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _)
+ /// # ]).unwrap();
+ /// # let channel: tonic::transport::Channel = unimplemented!();
+ /// let mut client = FlightClient::new(channel);
+ ///
+ /// // encode the batch as a stream of `FlightData`
+ /// let flight_data_stream = FlightDataEncoderBuilder::new()
+ /// .build(futures::stream::iter(vec![Ok(batch)]))
+ /// // the data encoder returns Results, but do_put requires FlightData
+ /// .map(|batch| batch.unwrap());
+ ///
+ /// // send the stream and get the results as `PutResult`
+ /// let response: Vec<PutResult> = client
+ /// .do_put(flight_data_stream)
+ /// .await
+ /// .unwrap()
+ /// .try_collect() // use TryStreamExt to collect stream
+ /// .await
+ /// .expect("error calling do_put");
+ /// # }
+ /// ```
+ pub async fn do_put<S: Stream<Item = FlightData> + Send + 'static>(
+ &mut self,
+ request: S,
+ ) -> Result<BoxStream<'static, Result<PutResult>>> {
+ let request = self.make_request(request); // attach any configured metadata headers
+
+ let response = self
+ .inner
+ .do_put(request)
+ .await? // failure to start the call is returned here
+ .into_inner()
+ .map_err(FlightError::Tonic); // errors inside the response stream
+
+ Ok(response.boxed())
+ }
+
+ /// Make a `DoExchange` call to the server with the provided
+ /// [`Stream`](futures::Stream) of [`FlightData`] and return a
+ /// [`FlightRecordBatchStream`] that decodes the [`FlightData`] the server sends back.
+ ///
+ /// # Example:
+ /// ```no_run
+ /// # async fn run() {
+ /// # use futures::{TryStreamExt, StreamExt};
+ /// # use std::sync::Arc;
+ /// # use arrow_array::UInt64Array;
+ /// # use arrow_array::RecordBatch;
+ /// # use arrow_flight::FlightClient;
+ /// # use arrow_flight::encode::FlightDataEncoderBuilder;
+ /// # let batch = RecordBatch::try_from_iter(vec![
+ /// # ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _)
+ /// # ]).unwrap();
+ /// # let channel: tonic::transport::Channel = unimplemented!();
+ /// let mut client = FlightClient::new(channel);
+ ///
+ /// // encode the batch as a stream of `FlightData`
+ /// let flight_data_stream = FlightDataEncoderBuilder::new()
+ /// .build(futures::stream::iter(vec![Ok(batch)]))
+ /// // the data encoder returns Results, but do_exchange requires FlightData
+ /// .map(|batch| batch.unwrap());
+ ///
+ /// // send the stream and get the results as decoded `RecordBatch`es
+ /// let response: Vec<RecordBatch> = client
+ /// .do_exchange(flight_data_stream)
+ /// .await
+ /// .unwrap()
+ /// .try_collect() // use TryStreamExt to collect stream
+ /// .await
+ /// .expect("error calling do_exchange");
+ /// # }
+ /// ```
+ pub async fn do_exchange<S: Stream<Item = FlightData> + Send + 'static>(
+ &mut self,
+ request: S,
+ ) -> Result<FlightRecordBatchStream> {
+ let request = self.make_request(request); // attach any configured metadata headers
+
+ let response = self
+ .inner
+ .do_exchange(request)
+ .await? // failure to start the call is returned here
+ .into_inner()
+ .map_err(FlightError::Tonic); // errors inside the response stream
+
+ Ok(FlightRecordBatchStream::new_from_flight_data(response))
+ }
+
+ /// Make a `ListFlights` call to the server with the provided
+ /// criteria expression and return a [`Stream`](futures::Stream) of [`FlightInfo`].
+ ///
+ /// # Example:
+ /// ```no_run
+ /// # async fn run() {
+ /// # use futures::TryStreamExt;
+ /// # use bytes::Bytes;
+ /// # use arrow_flight::{FlightInfo, FlightClient};
+ /// # let channel: tonic::transport::Channel = unimplemented!();
+ /// let mut client = FlightClient::new(channel);
+ ///
+ /// // Send 'Name=Foo' bytes as the "expression" to the server
+ /// // and gather the returned FlightInfo
+ /// let responses: Vec<FlightInfo> = client
+ /// .list_flights(Bytes::from("Name=Foo"))
+ /// .await
+ /// .expect("error listing flights")
+ /// .try_collect() // use TryStreamExt to collect stream
+ /// .await
+ /// .expect("error gathering flights");
+ /// # }
+ /// ```
+ pub async fn list_flights(
+ &mut self,
+ expression: impl Into<Bytes>,
+ ) -> Result<BoxStream<'static, Result<FlightInfo>>> {
+ let request = Criteria { // the expression bytes are passed through unchanged
+ expression: expression.into(),
+ };
+
+ let request = self.make_request(request); // attach any configured metadata headers
+
+ let response = self
+ .inner
+ .list_flights(request)
+ .await?
+ .into_inner()
+ .map_err(FlightError::Tonic); // errors inside the response stream
+
+ Ok(response.boxed())
+ }
+
+ /// Make a `GetSchema` call to the server with the provided
+ /// [`FlightDescriptor`] and return the associated [`Schema`].
+ ///
+ /// # Example:
+ /// ```no_run
+ /// # async fn run() {
+ /// # use bytes::Bytes;
+ /// # use arrow_flight::{FlightDescriptor, FlightClient};
+ /// # use arrow_schema::Schema;
+ /// # let channel: tonic::transport::Channel = unimplemented!();
+ /// let mut client = FlightClient::new(channel);
+ ///
+ /// // Request the schema result of a 'CMD' request to the server
+ /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
+ ///
+ /// let schema: Schema = client
+ /// .get_schema(request)
+ /// .await
+ /// .expect("error making request");
+ /// # }
+ /// ```
+ pub async fn get_schema(
+ &mut self,
+ flight_descriptor: FlightDescriptor,
+ ) -> Result<Schema> {
+ let request = self.make_request(flight_descriptor);
+
+ let schema_result = self.inner.get_schema(request).await?.into_inner();
+
+ // decode the IPC-encoded schema; `?` converts a decode ArrowError into a FlightError
+ let schema: Schema = schema_result.try_into()?;
+
+ Ok(schema)
+ }
+
+ /// Make a `ListActions` call to the server and return a
+ /// [`Stream`](futures::Stream) of [`ActionType`].
+ ///
+ /// # Example:
+ /// ```no_run
+ /// # async fn run() {
+ /// # use futures::TryStreamExt;
+ /// # use arrow_flight::{ActionType, FlightClient};
+ /// # use arrow_schema::Schema;
+ /// # let channel: tonic::transport::Channel = unimplemented!();
+ /// let mut client = FlightClient::new(channel);
+ ///
+ /// // List available actions on the server:
+ /// let actions: Vec<ActionType> = client
+ /// .list_actions()
+ /// .await
+ /// .expect("error listing actions")
+ /// .try_collect() // use TryStreamExt to collect stream
+ /// .await
+ /// .expect("error gathering actions");
+ /// # }
+ /// ```
+ pub async fn list_actions(
+ &mut self,
+ ) -> Result<BoxStream<'static, Result<ActionType>>> {
+ let request = self.make_request(Empty {}); // ListActions takes an empty request message
+
+ let action_stream = self
+ .inner
+ .list_actions(request)
+ .await?
+ .into_inner()
+ .map_err(FlightError::Tonic); // errors inside the response stream
+
+ Ok(action_stream.boxed())
+ }
+
+ /// Make a `DoAction` call to the server and return a
+ /// [`Stream`](futures::Stream) of opaque [`Bytes`], one per result body.
+ ///
+ /// # Example:
+ /// ```no_run
+ /// # async fn run() {
+ /// # use bytes::Bytes;
+ /// # use futures::TryStreamExt;
+ /// # use arrow_flight::{Action, FlightClient};
+ /// # use arrow_schema::Schema;
+ /// # let channel: tonic::transport::Channel = unimplemented!();
+ /// let mut client = FlightClient::new(channel);
+ ///
+ /// let request = Action::new("my_action", "the body");
+ ///
+ /// // Make a request to run the action on the server
+ /// let results: Vec<Bytes> = client
+ /// .do_action(request)
+ /// .await
+ /// .expect("error executing action")
+ /// .try_collect() // use TryStreamExt to collect stream
+ /// .await
+ /// .expect("error gathering action results");
+ /// # }
+ /// ```
+ pub async fn do_action(
+ &mut self,
+ action: Action,
+ ) -> Result<BoxStream<'static, Result<Bytes>>> {
+ let request = self.make_request(action); // attach any configured metadata headers
+
+ let result_stream = self
+ .inner
+ .do_action(request)
+ .await?
+ .into_inner()
+ .map_err(FlightError::Tonic)
+ .map(|r| {
+ r.map(|r| {
+ // keep only the opaque body bytes of each successful crate::Result message
+ let crate::Result { body } = r;
+ body
+ })
+ });
+
+ Ok(result_stream.boxed())
+ }
/// return a Request, adding any configured metadata
fn make_request<T>(&self, t: T) -> tonic::Request<T> {
diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs
index c2da58eb5..87aeba1c1 100644
--- a/arrow-flight/src/lib.rs
+++ b/arrow-flight/src/lib.rs
@@ -348,8 +348,7 @@ impl TryFrom<FlightInfo> for Schema {
type Error = ArrowError;
fn try_from(value: FlightInfo) -> ArrowResult<Self> {
- let msg = IpcMessage(value.schema);
- msg.try_into()
+ value.try_decode_schema() // decoding now lives in FlightInfo::try_decode_schema
}
}
@@ -368,6 +367,13 @@ impl TryFrom<&SchemaResult> for Schema {
}
}
+impl TryFrom<SchemaResult> for Schema {
+ type Error = ArrowError;
+ fn try_from(data: SchemaResult) -> ArrowResult<Self> {
+ (&data).try_into() // delegate to the TryFrom<&SchemaResult> impl
+ }
+}
+
// FlightData, FlightDescriptor, etc..
impl FlightData {
@@ -422,6 +428,12 @@ impl FlightInfo {
total_bytes,
}
}
+
+ /// Try to decode the IPC-encoded `schema` bytes of this `FlightInfo` into a [`Schema`], consuming `self`
+ pub fn try_decode_schema(self) -> ArrowResult<Schema> {
+ let msg = IpcMessage(self.schema);
+ msg.try_into()
+ }
}
impl<'a> SchemaAsIpc<'a> {
@@ -432,6 +444,16 @@ impl<'a> SchemaAsIpc<'a> {
}
}
+impl Action {
+ /// Create a new [`Action`] with the given `type` and opaque `body`, converting each via `Into`
+ pub fn new(action_type: impl Into<String>, body: impl Into<Bytes>) -> Self {
+ Self {
+ r#type: action_type.into(),
+ body: body.into(),
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs
index c471294d7..7537e46db 100644
--- a/arrow-flight/tests/client.rs
+++ b/arrow-flight/tests/client.rs
@@ -22,12 +22,13 @@ mod common {
}
use arrow_array::{RecordBatch, UInt64Array};
use arrow_flight::{
- error::FlightError, FlightClient, FlightDescriptor, FlightInfo,
HandshakeRequest,
- HandshakeResponse, Ticket,
+ decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder,
+ error::FlightError, FlightClient, FlightData, FlightDescriptor, FlightInfo,
+ HandshakeRequest, HandshakeResponse, PutResult, Ticket,
};
use bytes::Bytes;
use common::server::TestFlightServer;
-use futures::{Future, TryStreamExt};
+use futures::{Future, StreamExt, TryStreamExt};
use tokio::{net::TcpListener, task::JoinHandle};
use tonic::{
transport::{Channel, Uri},
@@ -41,8 +42,9 @@ const DEFAULT_TIMEOUT_SECONDS: u64 = 30;
#[tokio::test]
async fn test_handshake() {
do_test(|test_server, mut client| async move {
- let request_payload = Bytes::from("foo");
- let response_payload = Bytes::from("Bar");
+ client.add_header("foo-header", "bar-header-value").unwrap();
+ let request_payload = Bytes::from("foo-request-payload");
+ let response_payload = Bytes::from("bar-response-payload");
let request = HandshakeRequest {
payload: request_payload.clone(),
@@ -58,6 +60,7 @@ async fn test_handshake() {
let response = client.handshake(request_payload).await.unwrap();
assert_eq!(response, response_payload);
assert_eq!(test_server.take_handshake_request(), Some(request));
+ ensure_metadata(&client, &test_server);
})
.await;
}
@@ -65,7 +68,7 @@ async fn test_handshake() {
#[tokio::test]
async fn test_handshake_error() {
do_test(|test_server, mut client| async move {
- let request_payload = "foo".to_string().into_bytes();
+ let request_payload = "foo-request-payload".to_string().into_bytes();
let e = Status::unauthenticated("DENIED");
test_server.set_handshake_response(Err(e));
@@ -76,26 +79,6 @@ async fn test_handshake_error() {
.await;
}
-#[tokio::test]
-async fn test_handshake_metadata() {
- do_test(|test_server, mut client| async move {
- client.add_header("foo", "bar").unwrap();
-
- let request_payload = Bytes::from("Blarg");
- let response_payload = Bytes::from("Bazz");
-
- let response = HandshakeResponse {
- payload: response_payload.clone(),
- protocol_version: 0,
- };
-
- test_server.set_handshake_response(Ok(response));
- client.handshake(request_payload).await.unwrap();
- ensure_metadata(&client, &test_server);
- })
- .await;
-}
-
/// Verifies that all headers sent from the the client are in the
request_metadata
fn ensure_metadata(client: &FlightClient, test_server: &TestFlightServer) {
let client_metadata = client.metadata().clone().into_headers();
@@ -130,6 +113,7 @@ fn test_flight_info(request: &FlightDescriptor) -> FlightInfo {
#[tokio::test]
async fn test_get_flight_info() {
do_test(|test_server, mut client| async move {
+ client.add_header("foo-header", "bar-header-value").unwrap();
let request = FlightDescriptor::new_cmd(b"My Command".to_vec());
let expected_response = test_flight_info(&request);
@@ -139,6 +123,7 @@ async fn test_get_flight_info() {
assert_eq!(response, expected_response);
assert_eq!(test_server.take_get_flight_info_request(), Some(request));
+ ensure_metadata(&client, &test_server);
})
.await;
}
@@ -158,25 +143,12 @@ async fn test_get_flight_info_error() {
.await;
}
-#[tokio::test]
-async fn test_get_flight_info_metadata() {
- do_test(|test_server, mut client| async move {
- client.add_header("foo", "bar").unwrap();
- let request = FlightDescriptor::new_cmd(b"My Command".to_vec());
-
- let expected_response = test_flight_info(&request);
- test_server.set_get_flight_info_response(Ok(expected_response));
- client.get_flight_info(request.clone()).await.unwrap();
- ensure_metadata(&client, &test_server);
- })
- .await;
-}
-
// TODO more negative tests (like if there are endpoints defined, etc)
#[tokio::test]
async fn test_do_get() {
do_test(|test_server, mut client| async move {
+ client.add_header("foo-header", "bar-header-value").unwrap();
let ticket = Ticket {
ticket: Bytes::from("my awesome flight ticket"),
};
@@ -202,6 +174,7 @@ async fn test_do_get() {
assert_eq!(response, expected_response);
assert_eq!(test_server.take_do_get_request(), Some(ticket));
+ ensure_metadata(&client, &test_server);
})
.await;
}
@@ -209,7 +182,7 @@ async fn test_do_get() {
#[tokio::test]
async fn test_do_get_error() {
do_test(|test_server, mut client| async move {
- client.add_header("foo", "bar").unwrap();
+ client.add_header("foo-header", "bar-header-value").unwrap();
let ticket = Ticket {
ticket: Bytes::from("my awesome flight ticket"),
};
@@ -259,6 +232,248 @@ async fn test_do_get_error_in_record_batch_stream() {
.await;
}
+#[tokio::test]
+async fn test_do_put() {
+ do_test(|test_server, mut client| async move {
+ client.add_header("foo-header", "bar-header-value").unwrap();
+
+ // the FlightData the client sends: a test RecordBatch encoded by test_flight_data()
+ let input_flight_data = test_flight_data().await;
+
+ let expected_response = vec![
+ PutResult {
+ app_metadata: Bytes::from("foo-metadata1"),
+ },
+ PutResult {
+ app_metadata: Bytes::from("bar-metadata2"),
+ },
+ ];
+
+ test_server
+ .set_do_put_response(expected_response.clone().into_iter().map(Ok).collect());
+
+ let response_stream = client
+ .do_put(futures::stream::iter(input_flight_data.clone()))
+ .await
+ .expect("error making request");
+
+ let response: Vec<_> = response_stream
+ .try_collect()
+ .await
+ .expect("Error streaming data");
+
+ assert_eq!(response, expected_response);
+ assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
+ ensure_metadata(&client, &test_server);
+ })
+ .await;
+}
+
+#[tokio::test]
+async fn test_do_put_error() {
+ do_test(|test_server, mut client| async move {
+ client.add_header("foo-header", "bar-header-value").unwrap();
+
+ let input_flight_data = test_flight_data().await; // note: no do_put response is configured
+
+ let response = client
+ .do_put(futures::stream::iter(input_flight_data.clone()))
+ .await;
+ let response = match response {
+ Ok(_) => panic!("unexpected success"),
+ Err(e) => e,
+ };
+
+ let e = Status::internal("No do_put response configured"); // the test server's "no response set" error
+ expect_status(response, e);
+ // the server still received and recorded the full request stream
+ assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
+ ensure_metadata(&client, &test_server);
+ })
+ .await;
+}
+
+#[tokio::test]
+async fn test_do_put_error_stream() {
+ do_test(|test_server, mut client| async move {
+ client.add_header("foo-header", "bar-header-value").unwrap();
+
+ let input_flight_data = test_flight_data().await;
+
+ let response = vec![ // one good result followed by a mid-stream error
+ Ok(PutResult {
+ app_metadata: Bytes::from("foo-metadata"),
+ }),
+ Err(FlightError::Tonic(Status::invalid_argument("bad arg"))),
+ ];
+
+ test_server.set_do_put_response(response);
+
+ let response_stream = client
+ .do_put(futures::stream::iter(input_flight_data.clone()))
+ .await
+ .expect("error making request"); // the call itself succeeds; the error arrives in the stream
+
+ let response: Result<Vec<_>, _> = response_stream.try_collect().await;
+ let response = match response {
+ Ok(_) => panic!("unexpected success"),
+ Err(e) => e,
+ };
+
+ let e = Status::invalid_argument("bad arg");
+ expect_status(response, e);
+ // the server still received and recorded the full request stream
+ assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
+ ensure_metadata(&client, &test_server);
+ })
+ .await;
+}
+
+#[tokio::test]
+async fn test_do_exchange() {
+ do_test(|test_server, mut client| async move {
+ client.add_header("foo-header", "bar-header-value").unwrap();
+
+ // the client sends one encoded test batch; the server replies with another
+ let input_flight_data = test_flight_data().await;
+ let output_flight_data = test_flight_data2().await;
+
+ test_server.set_do_exchange_response(
+ output_flight_data.clone().into_iter().map(Ok).collect(),
+ );
+
+ let response_stream = client
+ .do_exchange(futures::stream::iter(input_flight_data.clone()))
+ .await
+ .expect("error making request");
+
+ let response: Vec<_> = response_stream
+ .try_collect()
+ .await
+ .expect("Error streaming data");
+
+ let expected_stream = futures::stream::iter(output_flight_data).map(Ok);
+
+ let expected_batches: Vec<_> =
+ FlightRecordBatchStream::new_from_flight_data(expected_stream)
+ .try_collect()
+ .await
+ .unwrap();
+
+ assert_eq!(response, expected_batches);
+ assert_eq!(
+ test_server.take_do_exchange_request(),
+ Some(input_flight_data)
+ );
+ ensure_metadata(&client, &test_server);
+ })
+ .await;
+}
+
+#[tokio::test]
+async fn test_do_exchange_error() {
+ do_test(|test_server, mut client| async move {
+ client.add_header("foo-header", "bar-header-value").unwrap();
+
+ let input_flight_data = test_flight_data().await; // note: no do_exchange response is configured
+
+ let response = client
+ .do_exchange(futures::stream::iter(input_flight_data.clone()))
+ .await;
+ let response = match response {
+ Ok(_) => panic!("unexpected success"),
+ Err(e) => e,
+ };
+
+ let e = Status::internal("No do_exchange response configured"); // the test server's "no response set" error
+ expect_status(response, e);
+ // the server still received and recorded the full request stream
+ assert_eq!(
+ test_server.take_do_exchange_request(),
+ Some(input_flight_data)
+ );
+ ensure_metadata(&client, &test_server);
+ })
+ .await;
+}
+
+#[tokio::test]
+async fn test_do_exchange_error_stream() {
+ do_test(|test_server, mut client| async move {
+ client.add_header("foo-header", "bar-header-value").unwrap();
+
+ let input_flight_data = test_flight_data().await;
+
+ let response = test_flight_data2()
+ .await
+ .into_iter()
+ .enumerate()
+ .map(|(i, m)| {
+ if i == 0 {
+ Ok(m)
+ } else {
+ // keep only the first message; every later one becomes an error so the stream fails mid-way
+ let e = tonic::Status::invalid_argument("the error");
+ Err(FlightError::Tonic(e))
+ }
+ })
+ .collect();
+
+ test_server.set_do_exchange_response(response);
+
+ let response_stream = client
+ .do_exchange(futures::stream::iter(input_flight_data.clone()))
+ .await
+ .expect("error making request"); // the call itself succeeds; the error arrives in the stream
+
+ let response: Result<Vec<_>, _> = response_stream.try_collect().await;
+ let response = match response {
+ Ok(_) => panic!("unexpected success"),
+ Err(e) => e,
+ };
+
+ let e = tonic::Status::invalid_argument("the error");
+ expect_status(response, e);
+ // the server still received and recorded the full request stream
+ assert_eq!(
+ test_server.take_do_exchange_request(),
+ Some(input_flight_data)
+ );
+ ensure_metadata(&client, &test_server);
+ })
+ .await;
+}
+
+async fn test_flight_data() -> Vec<FlightData> { // FlightData for a one-column ("col") UInt64 batch
+ let batch = RecordBatch::try_from_iter(vec![(
+ "col",
+ Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
+ )])
+ .unwrap();
+
+ // encode the batch into FlightData messages and collect them all
+ FlightDataEncoderBuilder::new()
+ .build(futures::stream::iter(vec![Ok(batch)]))
+ .try_collect()
+ .await
+ .unwrap()
+}
+
+async fn test_flight_data2() -> Vec<FlightData> { // FlightData for a different one-column ("col2") UInt64 batch
+ let batch = RecordBatch::try_from_iter(vec![(
+ "col2",
+ Arc::new(UInt64Array::from_iter([10, 23, 33])) as _,
+ )])
+ .unwrap();
+
+ // encode the batch into FlightData messages and collect them all
+ FlightDataEncoderBuilder::new()
+ .build(futures::stream::iter(vec![Ok(batch)]))
+ .try_collect()
+ .await
+ .unwrap()
+}
+
/// Runs the future returned by the function, passing it a test server and
client
async fn do_test<F, Fut>(f: F)
where
diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs
index 45f81b189..5060d9d0c 100644
--- a/arrow-flight/tests/common/server.rs
+++ b/arrow-flight/tests/common/server.rs
@@ -18,7 +18,7 @@
use std::sync::{Arc, Mutex};
use arrow_array::RecordBatch;
-use futures::{stream::BoxStream, TryStreamExt};
+use futures::{stream::BoxStream, StreamExt, TryStreamExt};
use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming};
use arrow_flight::{
@@ -98,6 +98,39 @@ impl TestFlightServer {
.take()
}
+ /// Specify the response stream returned from the next call to `do_put` (consumed by that call)
+ pub fn set_do_put_response(&self, response: Vec<Result<PutResult, FlightError>>) {
+ let mut state = self.state.lock().expect("mutex not poisoned");
+ state.do_put_response.replace(response);
+ }
+
+ /// Take and return the last do_put request sent to the server, leaving `None` in its place
+ pub fn take_do_put_request(&self) -> Option<Vec<FlightData>> {
+ self.state
+ .lock()
+ .expect("mutex not poisoned")
+ .do_put_request
+ .take()
+ }
+
+ /// Specify the response stream returned from the next call to `do_exchange` (consumed by that call)
+ pub fn set_do_exchange_response(
+ &self,
+ response: Vec<Result<FlightData, FlightError>>,
+ ) {
+ let mut state = self.state.lock().expect("mutex not poisoned");
+ state.do_exchange_response.replace(response);
+ }
+
+ /// Take and return the last do_exchange request sent to the server, leaving `None` in its place
+ pub fn take_do_exchange_request(&self) -> Option<Vec<FlightData>> {
+ self.state
+ .lock()
+ .expect("mutex not poisoned")
+ .do_exchange_request
+ .take()
+ }
+
/// Returns the last metadata from a request received by the server
pub fn take_last_request_metadata(&self) -> Option<MetadataMap> {
self.state
@@ -130,6 +163,14 @@ struct State {
pub do_get_request: Option<Ticket>,
/// The next response returned from `do_get`
pub do_get_response: Option<Vec<Result<RecordBatch, FlightError>>>,
+ /// The last do_put request received
+ pub do_put_request: Option<Vec<FlightData>>,
+ /// The next response returned from `do_put`
+ pub do_put_response: Option<Vec<Result<PutResult, FlightError>>>,
+ /// The last do_exchange request received
+ pub do_exchange_request: Option<Vec<FlightData>>,
+ /// The next response returned from `do_exchange`
+ pub do_exchange_response: Option<Vec<Result<FlightData, FlightError>>>,
/// The last request headers received
pub last_request_metadata: Option<MetadataMap>,
}
@@ -167,7 +208,7 @@ impl FlightService for TestFlightServer {
// turn into a streaming response
let output = futures::stream::iter(std::iter::once(Ok(response)));
- Ok(Response::new(Box::pin(output) as Self::HandshakeStream))
+ Ok(Response::new(output.boxed()))
}
async fn list_flights(
@@ -215,16 +256,30 @@ impl FlightService for TestFlightServer {
let stream = FlightDataEncoderBuilder::new()
.build(batch_stream)
- .map_err(|e| e.into());
+ .map_err(Into::into);
- Ok(Response::new(Box::pin(stream) as _))
+ Ok(Response::new(stream.boxed()))
}
async fn do_put(
&self,
- _request: Request<Streaming<FlightData>>,
+ request: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoPutStream>, Status> {
- Err(Status::unimplemented("Implement do_put"))
+ self.save_metadata(&request);
+ let do_put_request: Vec<_> = request.into_inner().try_collect().await?; // drain the whole client stream first
+
+ let mut state = self.state.lock().expect("mutex not poisoned"); // std Mutex: no .await while this guard is held
+
+ state.do_put_request = Some(do_put_request);
+
+ let response = state
+ .do_put_response
+ .take()
+ .ok_or_else(|| Status::internal("No do_put response configured"))?;
+
+ let stream = futures::stream::iter(response).map_err(Into::into); // FlightError -> Status
+
+ Ok(Response::new(stream.boxed()))
}
async fn do_action(
@@ -243,8 +298,22 @@ impl FlightService for TestFlightServer {
async fn do_exchange(
&self,
- _request: Request<Streaming<FlightData>>,
+ request: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoExchangeStream>, Status> {
- Err(Status::unimplemented("Implement do_exchange"))
+ self.save_metadata(&request);
+ let do_exchange_request: Vec<_> = request.into_inner().try_collect().await?; // drain the whole client stream first
+
+ let mut state = self.state.lock().expect("mutex not poisoned"); // std Mutex: no .await while this guard is held
+
+ state.do_exchange_request = Some(do_exchange_request);
+
+ let response = state
+ .do_exchange_response
+ .take()
+ .ok_or_else(|| Status::internal("No do_exchange response configured"))?;
+
+ let stream = futures::stream::iter(response).map_err(Into::into); // FlightError -> Status
+
+ Ok(Response::new(stream.boxed()))
}
}