This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 42ffc3f34 Complete mid-level `FlightClient` (#3402)
42ffc3f34 is described below

commit 42ffc3f344b338289d5e8e6b12b247f791dd5d8f
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Jan 5 09:39:28 2023 -0500

    Complete mid-level `FlightClient` (#3402)
    
    * Implement `FlightClient::do_put` and `FlightClient::do_exchange`
    
    * Implement ArrowClient::{list_flights, list_actions, do_action, get_schema}
    
    * Apply suggestions from code review
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
    Co-authored-by: Raphael Taylor-Davies <[email protected]>
    
    * remove outdated comment
    
    * make foo/bar placeholders in test more specific
    
    * simplify tests
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
    Co-authored-by: Raphael Taylor-Davies <[email protected]>
---
 arrow-flight/src/client.rs          | 303 +++++++++++++++++++++++++++++++++---
 arrow-flight/src/lib.rs             |  26 +++-
 arrow-flight/tests/client.rs        | 297 ++++++++++++++++++++++++++++++-----
 arrow-flight/tests/common/server.rs |  85 +++++++++-
 4 files changed, 638 insertions(+), 73 deletions(-)

diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index 753c40f2a..bdd51dda4 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -16,11 +16,17 @@
 // under the License.
 
 use crate::{
-    decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient,
-    FlightDescriptor, FlightInfo, HandshakeRequest, Ticket,
+    decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action,
+    ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
+    HandshakeRequest, PutResult, Ticket,
 };
+use arrow_schema::Schema;
 use bytes::Bytes;
-use futures::{future::ready, stream, StreamExt, TryStreamExt};
+use futures::{
+    future::ready,
+    stream::{self, BoxStream},
+    Stream, StreamExt, TryStreamExt,
+};
 use tonic::{metadata::MetadataMap, transport::Channel};
 
 use crate::error::{FlightError, Result};
@@ -160,6 +166,11 @@ impl FlightClient {
     /// returning a [`FlightRecordBatchStream`] for reading
     /// [`RecordBatch`](arrow_array::RecordBatch)es.
     ///
+    /// # Note
+    ///
+    /// To access the returned [`FlightData`] use
+    /// [`FlightRecordBatchStream::into_inner()`]
+    ///
     /// # Example:
     /// ```no_run
     /// # async fn run() {
@@ -167,12 +178,8 @@ impl FlightClient {
     /// # use arrow_flight::FlightClient;
     /// # use arrow_flight::Ticket;
     /// # use arrow_array::RecordBatch;
-    /// # use tonic::transport::Channel;
     /// # use futures::stream::TryStreamExt;
-    /// # let channel = Channel::from_static("http://localhost:1234")
-    /// #  .connect()
-    /// #  .await
-    /// #  .expect("error connecting");
+    /// # let channel: tonic::transport::Channel = unimplemented!();
     /// # let ticket = Ticket { ticket: Bytes::from("foo") };
     /// let mut client = FlightClient::new(channel);
     ///
@@ -199,8 +206,7 @@ impl FlightClient {
             .do_get(request)
             .await?
             .into_inner()
-            // convert to FlightError
-            .map_err(|e| e.into());
+            .map_err(FlightError::Tonic);
 
         Ok(FlightRecordBatchStream::new_from_flight_data(
             response_stream,
@@ -217,11 +223,7 @@ impl FlightClient {
     /// # async fn run() {
     /// # use arrow_flight::FlightClient;
     /// # use arrow_flight::FlightDescriptor;
-    /// # use tonic::transport::Channel;
-    /// # let channel = Channel::from_static("http://localhost:1234")
-    /// #   .connect()
-    /// #   .await
-    /// #   .expect("error connecting");
+    /// # let channel: tonic::transport::Channel = unimplemented!();
     /// let mut client = FlightClient::new(channel);
     ///
     /// // Send a 'CMD' request to the server
@@ -256,13 +258,270 @@ impl FlightClient {
         Ok(response)
     }
 
-    // TODO other methods
-    // list_flights
-    // get_schema
-    // do_put
-    // do_action
-    // list_actions
-    // do_exchange
+    /// Make a `DoPut` call to the server with the provided
+    /// [`Stream`](futures::Stream) of [`FlightData`] and returning a
+    /// stream of [`PutResult`].
+    ///
+    /// # Example:
+    /// ```no_run
+    /// # async fn run() {
+    /// # use futures::{TryStreamExt, StreamExt};
+    /// # use std::sync::Arc;
+    /// # use arrow_array::UInt64Array;
+    /// # use arrow_array::RecordBatch;
+    /// # use arrow_flight::{FlightClient, FlightDescriptor, PutResult};
+    /// # use arrow_flight::encode::FlightDataEncoderBuilder;
+    /// # let batch = RecordBatch::try_from_iter(vec![
+    /// #  ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _)
+    /// # ]).unwrap();
+    /// # let channel: tonic::transport::Channel = unimplemented!();
+    /// let mut client = FlightClient::new(channel);
+    ///
+    /// // encode the batch as a stream of `FlightData`
+    /// let flight_data_stream = FlightDataEncoderBuilder::new()
+    ///   .build(futures::stream::iter(vec![Ok(batch)]))
+    ///   // the data encoder returns Results, but do_put requires FlightData
+    ///   .map(|batch|batch.unwrap());
+    ///
+    /// // send the stream and get the results as `PutResult`
+    /// let response: Vec<PutResult>= client
+    ///   .do_put(flight_data_stream)
+    ///   .await
+    ///   .unwrap()
+    ///   .try_collect() // use TryStreamExt to collect stream
+    ///   .await
+    ///   .expect("error calling do_put");
+    /// # }
+    /// ```
+    pub async fn do_put<S: Stream<Item = FlightData> + Send + 'static>(
+        &mut self,
+        request: S,
+    ) -> Result<BoxStream<'static, Result<PutResult>>> {
+        let request = self.make_request(request);
+
+        let response = self
+            .inner
+            .do_put(request)
+            .await?
+            .into_inner()
+            .map_err(FlightError::Tonic);
+
+        Ok(response.boxed())
+    }
+
+    /// Make a `DoExchange` call to the server with the provided
+    /// [`Stream`](futures::Stream) of [`FlightData`] and returning a
+    /// stream of [`FlightData`].
+    ///
+    /// # Example:
+    /// ```no_run
+    /// # async fn run() {
+    /// # use futures::{TryStreamExt, StreamExt};
+    /// # use std::sync::Arc;
+    /// # use arrow_array::UInt64Array;
+    /// # use arrow_array::RecordBatch;
+    /// # use arrow_flight::{FlightClient, FlightDescriptor, PutResult};
+    /// # use arrow_flight::encode::FlightDataEncoderBuilder;
+    /// # let batch = RecordBatch::try_from_iter(vec![
+    /// #  ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _)
+    /// # ]).unwrap();
+    /// # let channel: tonic::transport::Channel = unimplemented!();
+    /// let mut client = FlightClient::new(channel);
+    ///
+    /// // encode the batch as a stream of `FlightData`
+    /// let flight_data_stream = FlightDataEncoderBuilder::new()
+    ///   .build(futures::stream::iter(vec![Ok(batch)]))
+    ///   // the data encoder returns Results, but do_exchange requires FlightData
+    ///   .map(|batch|batch.unwrap());
+    ///
+    /// // send the stream and get the results as `RecordBatches`
+    /// let response: Vec<RecordBatch> = client
+    ///   .do_exchange(flight_data_stream)
+    ///   .await
+    ///   .unwrap()
+    ///   .try_collect() // use TryStreamExt to collect stream
+    ///   .await
+    ///   .expect("error calling do_exchange");
+    /// # }
+    /// ```
+    pub async fn do_exchange<S: Stream<Item = FlightData> + Send + 'static>(
+        &mut self,
+        request: S,
+    ) -> Result<FlightRecordBatchStream> {
+        let request = self.make_request(request);
+
+        let response = self
+            .inner
+            .do_exchange(request)
+            .await?
+            .into_inner()
+            .map_err(FlightError::Tonic);
+
+        Ok(FlightRecordBatchStream::new_from_flight_data(response))
+    }
+
+    /// Make a `ListFlights` call to the server with the provided
+    /// criteria and returning a [`Stream`](futures::Stream) of [`FlightInfo`].
+    ///
+    /// # Example:
+    /// ```no_run
+    /// # async fn run() {
+    /// # use futures::TryStreamExt;
+    /// # use bytes::Bytes;
+    /// # use arrow_flight::{FlightInfo, FlightClient};
+    /// # let channel: tonic::transport::Channel = unimplemented!();
+    /// let mut client = FlightClient::new(channel);
+    ///
+    /// // Send 'Name=Foo' bytes as the "expression" to the server
+    /// // and gather the returned FlightInfo
+    /// let responses: Vec<FlightInfo> = client
+    ///   .list_flights(Bytes::from("Name=Foo"))
+    ///   .await
+    ///   .expect("error listing flights")
+    ///   .try_collect() // use TryStreamExt to collect stream
+    ///   .await
+    ///   .expect("error gathering flights");
+    /// # }
+    /// ```
+    pub async fn list_flights(
+        &mut self,
+        expression: impl Into<Bytes>,
+    ) -> Result<BoxStream<'static, Result<FlightInfo>>> {
+        let request = Criteria {
+            expression: expression.into(),
+        };
+
+        let request = self.make_request(request);
+
+        let response = self
+            .inner
+            .list_flights(request)
+            .await?
+            .into_inner()
+            .map_err(FlightError::Tonic);
+
+        Ok(response.boxed())
+    }
+
+    /// Make a `GetSchema` call to the server with the provided
+    /// [`FlightDescriptor`] and returning the associated [`Schema`].
+    ///
+    /// # Example:
+    /// ```no_run
+    /// # async fn run() {
+    /// # use bytes::Bytes;
+    /// # use arrow_flight::{FlightDescriptor, FlightClient};
+    /// # use arrow_schema::Schema;
+    /// # let channel: tonic::transport::Channel = unimplemented!();
+    /// let mut client = FlightClient::new(channel);
+    ///
+    /// // Request the schema result of a 'CMD' request to the server
+    /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
+    ///
+    /// let schema: Schema = client
+    ///   .get_schema(request)
+    ///   .await
+    ///   .expect("error making request");
+    /// # }
+    /// ```
+    pub async fn get_schema(
+        &mut self,
+        flight_descriptor: FlightDescriptor,
+    ) -> Result<Schema> {
+        let request = self.make_request(flight_descriptor);
+
+        let schema_result = self.inner.get_schema(request).await?.into_inner();
+
+        // attempt decode from IPC
+        let schema: Schema = schema_result.try_into()?;
+
+        Ok(schema)
+    }
+
+    /// Make a `ListActions` call to the server and returning a
+    /// [`Stream`](futures::Stream) of [`ActionType`].
+    ///
+    /// # Example:
+    /// ```no_run
+    /// # async fn run() {
+    /// # use futures::TryStreamExt;
+    /// # use arrow_flight::{ActionType, FlightClient};
+    /// # use arrow_schema::Schema;
+    /// # let channel: tonic::transport::Channel = unimplemented!();
+    /// let mut client = FlightClient::new(channel);
+    ///
+    /// // List available actions on the server:
+    /// let actions: Vec<ActionType> = client
+    ///   .list_actions()
+    ///   .await
+    ///   .expect("error listing actions")
+    ///   .try_collect() // use TryStreamExt to collect stream
+    ///   .await
+    ///   .expect("error gathering actions");
+    /// # }
+    /// ```
+    pub async fn list_actions(
+        &mut self,
+    ) -> Result<BoxStream<'static, Result<ActionType>>> {
+        let request = self.make_request(Empty {});
+
+        let action_stream = self
+            .inner
+            .list_actions(request)
+            .await?
+            .into_inner()
+            .map_err(FlightError::Tonic);
+
+        Ok(action_stream.boxed())
+    }
+
+    /// Make a `DoAction` call to the server and returning a
+    /// [`Stream`](futures::Stream) of opaque [`Bytes`].
+    ///
+    /// # Example:
+    /// ```no_run
+    /// # async fn run() {
+    /// # use bytes::Bytes;
+    /// # use futures::TryStreamExt;
+    /// # use arrow_flight::{Action, FlightClient};
+    /// # use arrow_schema::Schema;
+    /// # let channel: tonic::transport::Channel = unimplemented!();
+    /// let mut client = FlightClient::new(channel);
+    ///
+    /// let request = Action::new("my_action", "the body");
+    ///
+    /// // Make a request to run the action on the server
+    /// let results: Vec<Bytes> = client
+    ///   .do_action(request)
+    ///   .await
+    ///   .expect("error executing action")
+    ///   .try_collect() // use TryStreamExt to collect stream
+    ///   .await
+    ///   .expect("error gathering action results");
+    /// # }
+    /// ```
+    pub async fn do_action(
+        &mut self,
+        action: Action,
+    ) -> Result<BoxStream<'static, Result<Bytes>>> {
+        let request = self.make_request(action);
+
+        let result_stream = self
+            .inner
+            .do_action(request)
+            .await?
+            .into_inner()
+            .map_err(FlightError::Tonic)
+            .map(|r| {
+                r.map(|r| {
+                    // unwrap inner bytes
+                    let crate::Result { body } = r;
+                    body
+                })
+            });
+
+        Ok(result_stream.boxed())
+    }
 
     /// return a Request, adding any configured metadata
     fn make_request<T>(&self, t: T) -> tonic::Request<T> {
diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs
index c2da58eb5..87aeba1c1 100644
--- a/arrow-flight/src/lib.rs
+++ b/arrow-flight/src/lib.rs
@@ -348,8 +348,7 @@ impl TryFrom<FlightInfo> for Schema {
     type Error = ArrowError;
 
     fn try_from(value: FlightInfo) -> ArrowResult<Self> {
-        let msg = IpcMessage(value.schema);
-        msg.try_into()
+        value.try_decode_schema()
     }
 }
 
@@ -368,6 +367,13 @@ impl TryFrom<&SchemaResult> for Schema {
     }
 }
 
+impl TryFrom<SchemaResult> for Schema {
+    type Error = ArrowError;
+    fn try_from(data: SchemaResult) -> ArrowResult<Self> {
+        (&data).try_into()
+    }
+}
+
 // FlightData, FlightDescriptor, etc..
 
 impl FlightData {
@@ -422,6 +428,12 @@ impl FlightInfo {
             total_bytes,
         }
     }
+
+    /// Try to convert the data in this `FlightInfo` into a [`Schema`]
+    pub fn try_decode_schema(self) -> ArrowResult<Schema> {
+        let msg = IpcMessage(self.schema);
+        msg.try_into()
+    }
 }
 
 impl<'a> SchemaAsIpc<'a> {
@@ -432,6 +444,16 @@ impl<'a> SchemaAsIpc<'a> {
     }
 }
 
+impl Action {
+    /// Create a new Action with type and body
+    pub fn new(action_type: impl Into<String>, body: impl Into<Bytes>) -> Self {
+        Self {
+            r#type: action_type.into(),
+            body: body.into(),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs
index c471294d7..7537e46db 100644
--- a/arrow-flight/tests/client.rs
+++ b/arrow-flight/tests/client.rs
@@ -22,12 +22,13 @@ mod common {
 }
 use arrow_array::{RecordBatch, UInt64Array};
 use arrow_flight::{
-    error::FlightError, FlightClient, FlightDescriptor, FlightInfo, HandshakeRequest,
-    HandshakeResponse, Ticket,
+    decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder,
+    error::FlightError, FlightClient, FlightData, FlightDescriptor, FlightInfo,
+    HandshakeRequest, HandshakeResponse, PutResult, Ticket,
 };
 use bytes::Bytes;
 use common::server::TestFlightServer;
-use futures::{Future, TryStreamExt};
+use futures::{Future, StreamExt, TryStreamExt};
 use tokio::{net::TcpListener, task::JoinHandle};
 use tonic::{
     transport::{Channel, Uri},
@@ -41,8 +42,9 @@ const DEFAULT_TIMEOUT_SECONDS: u64 = 30;
 #[tokio::test]
 async fn test_handshake() {
     do_test(|test_server, mut client| async move {
-        let request_payload = Bytes::from("foo");
-        let response_payload = Bytes::from("Bar");
+        client.add_header("foo-header", "bar-header-value").unwrap();
+        let request_payload = Bytes::from("foo-request-payload");
+        let response_payload = Bytes::from("bar-response-payload");
 
         let request = HandshakeRequest {
             payload: request_payload.clone(),
@@ -58,6 +60,7 @@ async fn test_handshake() {
         let response = client.handshake(request_payload).await.unwrap();
         assert_eq!(response, response_payload);
         assert_eq!(test_server.take_handshake_request(), Some(request));
+        ensure_metadata(&client, &test_server);
     })
     .await;
 }
@@ -65,7 +68,7 @@ async fn test_handshake() {
 #[tokio::test]
 async fn test_handshake_error() {
     do_test(|test_server, mut client| async move {
-        let request_payload = "foo".to_string().into_bytes();
+        let request_payload = "foo-request-payload".to_string().into_bytes();
         let e = Status::unauthenticated("DENIED");
         test_server.set_handshake_response(Err(e));
 
@@ -76,26 +79,6 @@ async fn test_handshake_error() {
     .await;
 }
 
-#[tokio::test]
-async fn test_handshake_metadata() {
-    do_test(|test_server, mut client| async move {
-        client.add_header("foo", "bar").unwrap();
-
-        let request_payload = Bytes::from("Blarg");
-        let response_payload = Bytes::from("Bazz");
-
-        let response = HandshakeResponse {
-            payload: response_payload.clone(),
-            protocol_version: 0,
-        };
-
-        test_server.set_handshake_response(Ok(response));
-        client.handshake(request_payload).await.unwrap();
-        ensure_metadata(&client, &test_server);
-    })
-    .await;
-}
-
 /// Verifies that all headers sent from the the client are in the request_metadata
 fn ensure_metadata(client: &FlightClient, test_server: &TestFlightServer) {
     let client_metadata = client.metadata().clone().into_headers();
@@ -130,6 +113,7 @@ fn test_flight_info(request: &FlightDescriptor) -> FlightInfo {
 #[tokio::test]
 async fn test_get_flight_info() {
     do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
         let request = FlightDescriptor::new_cmd(b"My Command".to_vec());
 
         let expected_response = test_flight_info(&request);
@@ -139,6 +123,7 @@ async fn test_get_flight_info() {
 
         assert_eq!(response, expected_response);
         assert_eq!(test_server.take_get_flight_info_request(), Some(request));
+        ensure_metadata(&client, &test_server);
     })
     .await;
 }
@@ -158,25 +143,12 @@ async fn test_get_flight_info_error() {
     .await;
 }
 
-#[tokio::test]
-async fn test_get_flight_info_metadata() {
-    do_test(|test_server, mut client| async move {
-        client.add_header("foo", "bar").unwrap();
-        let request = FlightDescriptor::new_cmd(b"My Command".to_vec());
-
-        let expected_response = test_flight_info(&request);
-        test_server.set_get_flight_info_response(Ok(expected_response));
-        client.get_flight_info(request.clone()).await.unwrap();
-        ensure_metadata(&client, &test_server);
-    })
-    .await;
-}
-
 // TODO more negative  tests (like if there are endpoints defined, etc)
 
 #[tokio::test]
 async fn test_do_get() {
     do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
         let ticket = Ticket {
             ticket: Bytes::from("my awesome flight ticket"),
         };
@@ -202,6 +174,7 @@ async fn test_do_get() {
 
         assert_eq!(response, expected_response);
         assert_eq!(test_server.take_do_get_request(), Some(ticket));
+        ensure_metadata(&client, &test_server);
     })
     .await;
 }
@@ -209,7 +182,7 @@ async fn test_do_get() {
 #[tokio::test]
 async fn test_do_get_error() {
     do_test(|test_server, mut client| async move {
-        client.add_header("foo", "bar").unwrap();
+        client.add_header("foo-header", "bar-header-value").unwrap();
         let ticket = Ticket {
             ticket: Bytes::from("my awesome flight ticket"),
         };
@@ -259,6 +232,248 @@ async fn test_do_get_error_in_record_batch_stream() {
     .await;
 }
 
+#[tokio::test]
+async fn test_do_put() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        // encode the batch as a stream of FlightData
+        let input_flight_data = test_flight_data().await;
+
+        let expected_response = vec![
+            PutResult {
+                app_metadata: Bytes::from("foo-metadata1"),
+            },
+            PutResult {
+                app_metadata: Bytes::from("bar-metadata2"),
+            },
+        ];
+
+        test_server
+            .set_do_put_response(expected_response.clone().into_iter().map(Ok).collect());
+
+        let response_stream = client
+            .do_put(futures::stream::iter(input_flight_data.clone()))
+            .await
+            .expect("error making request");
+
+        let response: Vec<_> = response_stream
+            .try_collect()
+            .await
+            .expect("Error streaming data");
+
+        assert_eq!(response, expected_response);
+        assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_do_put_error() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let input_flight_data = test_flight_data().await;
+
+        let response = client
+            .do_put(futures::stream::iter(input_flight_data.clone()))
+            .await;
+        let response = match response {
+            Ok(_) => panic!("unexpected success"),
+            Err(e) => e,
+        };
+
+        let e = Status::internal("No do_put response configured");
+        expect_status(response, e);
+        // server still got the request
+        assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_do_put_error_stream() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let input_flight_data = test_flight_data().await;
+
+        let response = vec![
+            Ok(PutResult {
+                app_metadata: Bytes::from("foo-metadata"),
+            }),
+            Err(FlightError::Tonic(Status::invalid_argument("bad arg"))),
+        ];
+
+        test_server.set_do_put_response(response);
+
+        let response_stream = client
+            .do_put(futures::stream::iter(input_flight_data.clone()))
+            .await
+            .expect("error making request");
+
+        let response: Result<Vec<_>, _> = response_stream.try_collect().await;
+        let response = match response {
+            Ok(_) => panic!("unexpected success"),
+            Err(e) => e,
+        };
+
+        let e = Status::invalid_argument("bad arg");
+        expect_status(response, e);
+        // server still got the request
+        assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_do_exchange() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        // encode the batch as a stream of FlightData
+        let input_flight_data = test_flight_data().await;
+        let output_flight_data = test_flight_data2().await;
+
+        test_server.set_do_exchange_response(
+            output_flight_data.clone().into_iter().map(Ok).collect(),
+        );
+
+        let response_stream = client
+            .do_exchange(futures::stream::iter(input_flight_data.clone()))
+            .await
+            .expect("error making request");
+
+        let response: Vec<_> = response_stream
+            .try_collect()
+            .await
+            .expect("Error streaming data");
+
+        let expected_stream = futures::stream::iter(output_flight_data).map(Ok);
+
+        let expected_batches: Vec<_> =
+            FlightRecordBatchStream::new_from_flight_data(expected_stream)
+                .try_collect()
+                .await
+                .unwrap();
+
+        assert_eq!(response, expected_batches);
+        assert_eq!(
+            test_server.take_do_exchange_request(),
+            Some(input_flight_data)
+        );
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_do_exchange_error() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let input_flight_data = test_flight_data().await;
+
+        let response = client
+            .do_exchange(futures::stream::iter(input_flight_data.clone()))
+            .await;
+        let response = match response {
+            Ok(_) => panic!("unexpected success"),
+            Err(e) => e,
+        };
+
+        let e = Status::internal("No do_exchange response configured");
+        expect_status(response, e);
+        // server still got the request
+        assert_eq!(
+            test_server.take_do_exchange_request(),
+            Some(input_flight_data)
+        );
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_do_exchange_error_stream() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let input_flight_data = test_flight_data().await;
+
+        let response = test_flight_data2()
+            .await
+            .into_iter()
+            .enumerate()
+            .map(|(i, m)| {
+                if i == 0 {
+                    Ok(m)
+                } else {
+                    // make all messages after the first an error
+                    let e = tonic::Status::invalid_argument("the error");
+                    Err(FlightError::Tonic(e))
+                }
+            })
+            .collect();
+
+        test_server.set_do_exchange_response(response);
+
+        let response_stream = client
+            .do_exchange(futures::stream::iter(input_flight_data.clone()))
+            .await
+            .expect("error making request");
+
+        let response: Result<Vec<_>, _> = response_stream.try_collect().await;
+        let response = match response {
+            Ok(_) => panic!("unexpected success"),
+            Err(e) => e,
+        };
+
+        let e = tonic::Status::invalid_argument("the error");
+        expect_status(response, e);
+        // server still got the request
+        assert_eq!(
+            test_server.take_do_exchange_request(),
+            Some(input_flight_data)
+        );
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+async fn test_flight_data() -> Vec<FlightData> {
+    let batch = RecordBatch::try_from_iter(vec![(
+        "col",
+        Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
+    )])
+    .unwrap();
+
+    // encode the batch as a stream of FlightData
+    FlightDataEncoderBuilder::new()
+        .build(futures::stream::iter(vec![Ok(batch)]))
+        .try_collect()
+        .await
+        .unwrap()
+}
+
+async fn test_flight_data2() -> Vec<FlightData> {
+    let batch = RecordBatch::try_from_iter(vec![(
+        "col2",
+        Arc::new(UInt64Array::from_iter([10, 23, 33])) as _,
+    )])
+    .unwrap();
+
+    // encode the batch as a stream of FlightData
+    FlightDataEncoderBuilder::new()
+        .build(futures::stream::iter(vec![Ok(batch)]))
+        .try_collect()
+        .await
+        .unwrap()
+}
+
 /// Runs the future returned by the function,  passing it a test server and client
 async fn do_test<F, Fut>(f: F)
 where
diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs
index 45f81b189..5060d9d0c 100644
--- a/arrow-flight/tests/common/server.rs
+++ b/arrow-flight/tests/common/server.rs
@@ -18,7 +18,7 @@
 use std::sync::{Arc, Mutex};
 
 use arrow_array::RecordBatch;
-use futures::{stream::BoxStream, TryStreamExt};
+use futures::{stream::BoxStream, StreamExt, TryStreamExt};
 use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming};
 
 use arrow_flight::{
@@ -98,6 +98,39 @@ impl TestFlightServer {
             .take()
     }
 
+    /// Specify the response returned from the next call to `do_put`
+    pub fn set_do_put_response(&self, response: Vec<Result<PutResult, FlightError>>) {
+        let mut state = self.state.lock().expect("mutex not poisoned");
+        state.do_put_response.replace(response);
+    }
+
+    /// Take and return the last do_put request sent to the server
+    pub fn take_do_put_request(&self) -> Option<Vec<FlightData>> {
+        self.state
+            .lock()
+            .expect("mutex not poisoned")
+            .do_put_request
+            .take()
+    }
+
+    /// Specify the response returned from the next call to `do_exchange`
+    pub fn set_do_exchange_response(
+        &self,
+        response: Vec<Result<FlightData, FlightError>>,
+    ) {
+        let mut state = self.state.lock().expect("mutex not poisoned");
+        state.do_exchange_response.replace(response);
+    }
+
+    /// Take and return the last do_exchange request sent to the server
+    pub fn take_do_exchange_request(&self) -> Option<Vec<FlightData>> {
+        self.state
+            .lock()
+            .expect("mutex not poisoned")
+            .do_exchange_request
+            .take()
+    }
+
     /// Returns the last metadata from a request received by the server
     pub fn take_last_request_metadata(&self) -> Option<MetadataMap> {
         self.state
@@ -130,6 +163,14 @@ struct State {
     pub do_get_request: Option<Ticket>,
     /// The next response returned from `do_get`
     pub do_get_response: Option<Vec<Result<RecordBatch, FlightError>>>,
+    /// The last do_put request received
+    pub do_put_request: Option<Vec<FlightData>>,
+    /// The next response returned from `do_put`
+    pub do_put_response: Option<Vec<Result<PutResult, FlightError>>>,
+    /// The last do_exchange request received
+    pub do_exchange_request: Option<Vec<FlightData>>,
+    /// The next response returned from `do_exchange`
+    pub do_exchange_response: Option<Vec<Result<FlightData, FlightError>>>,
     /// The last request headers received
     pub last_request_metadata: Option<MetadataMap>,
 }
@@ -167,7 +208,7 @@ impl FlightService for TestFlightServer {
 
         // turn into a streaming response
         let output = futures::stream::iter(std::iter::once(Ok(response)));
-        Ok(Response::new(Box::pin(output) as Self::HandshakeStream))
+        Ok(Response::new(output.boxed()))
     }
 
     async fn list_flights(
@@ -215,16 +256,30 @@ impl FlightService for TestFlightServer {
 
         let stream = FlightDataEncoderBuilder::new()
             .build(batch_stream)
-            .map_err(|e| e.into());
+            .map_err(Into::into);
 
-        Ok(Response::new(Box::pin(stream) as _))
+        Ok(Response::new(stream.boxed()))
     }
 
     async fn do_put(
         &self,
-        _request: Request<Streaming<FlightData>>,
+        request: Request<Streaming<FlightData>>,
     ) -> Result<Response<Self::DoPutStream>, Status> {
-        Err(Status::unimplemented("Implement do_put"))
+        self.save_metadata(&request);
+        let do_put_request: Vec<_> = request.into_inner().try_collect().await?;
+
+        let mut state = self.state.lock().expect("mutex not poisoned");
+
+        state.do_put_request = Some(do_put_request);
+
+        let response = state
+            .do_put_response
+            .take()
+            .ok_or_else(|| Status::internal("No do_put response configured"))?;
+
+        let stream = futures::stream::iter(response).map_err(Into::into);
+
+        Ok(Response::new(stream.boxed()))
     }
 
     async fn do_action(
@@ -243,8 +298,22 @@ impl FlightService for TestFlightServer {
 
     async fn do_exchange(
         &self,
-        _request: Request<Streaming<FlightData>>,
+        request: Request<Streaming<FlightData>>,
     ) -> Result<Response<Self::DoExchangeStream>, Status> {
-        Err(Status::unimplemented("Implement do_exchange"))
+        self.save_metadata(&request);
+        let do_exchange_request: Vec<_> = request.into_inner().try_collect().await?;
+
+        let mut state = self.state.lock().expect("mutex not poisoned");
+
+        state.do_exchange_request = Some(do_exchange_request);
+
+        let response = state
+            .do_exchange_response
+            .take()
+            .ok_or_else(|| Status::internal("No do_exchange response configured"))?;
+
+        let stream = futures::stream::iter(response).map_err(Into::into);
+
+        Ok(Response::new(stream.boxed()))
     }
 }

Reply via email to