This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 3ae1c728b Expose Inner FlightServiceClient on FlightSqlServiceClient (#3551) (#3556)
3ae1c728b is described below

commit 3ae1c728b266c1ba801409eb7f4b901285783e94
Author: Raphael Taylor-Davies <1781103+tustv...@users.noreply.github.com>
AuthorDate: Wed Jan 18 17:38:07 2023 +0000

    Expose Inner FlightServiceClient on FlightSqlServiceClient (#3551) (#3556)
    
    * Remove unnecessary Mutex from FlightSqlServiceClient (#3551)
    
    * Add inner and inner_mut
    
    * Add into_inner
---
 arrow-flight/src/sql/client.rs | 53 ++++++++++++++++++++----------------------
 1 file changed, 25 insertions(+), 28 deletions(-)

diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs
index ecc121d98..5c5f84b3d 100644
--- a/arrow-flight/src/sql/client.rs
+++ b/arrow-flight/src/sql/client.rs
@@ -19,7 +19,6 @@ use base64::prelude::BASE64_STANDARD;
 use base64::Engine;
 use bytes::Bytes;
 use std::collections::HashMap;
-use std::sync::Arc;
 use std::time::Duration;
 
 use crate::flight_service_client::FlightServiceClient;
@@ -45,7 +44,6 @@ use arrow_ipc::{root_as_message, MessageHeader};
 use arrow_schema::{ArrowError, Schema, SchemaRef};
 use futures::{stream, TryStreamExt};
 use prost::Message;
-use tokio::sync::{Mutex, MutexGuard};
 #[cfg(feature = "tls")]
 use tonic::transport::{Certificate, ClientTlsConfig, Identity};
 use tonic::transport::{Channel, Endpoint};
@@ -56,7 +54,7 @@ use tonic::Streaming;
 #[derive(Debug, Clone)]
 pub struct FlightSqlServiceClient {
     token: Option<String>,
-    flight_client: Arc<Mutex<FlightServiceClient<Channel>>>,
+    flight_client: FlightServiceClient<Channel>,
 }
 
 /// A FlightSql protocol client that can run queries against FlightSql servers
@@ -124,16 +122,23 @@ impl FlightSqlServiceClient {
         let flight_client = FlightServiceClient::new(channel);
         FlightSqlServiceClient {
             token: None,
-            flight_client: Arc::new(Mutex::new(flight_client)),
+            flight_client,
         }
     }
 
-    fn mut_client(
-        &mut self,
-    ) -> Result<MutexGuard<FlightServiceClient<Channel>>, ArrowError> {
+    /// Return a reference to the underlying [`FlightServiceClient`]
+    pub fn inner(&self) -> &FlightServiceClient<Channel> {
+        &self.flight_client
+    }
+
+    /// Return a mutable reference to the underlying [`FlightServiceClient`]
+    pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
+        &mut self.flight_client
+    }
+
+    /// Consume this client and return the underlying [`FlightServiceClient`]
+    pub fn into_inner(self) -> FlightServiceClient<Channel> {
         self.flight_client
-            .try_lock()
-            .map_err(|_| ArrowError::IoError("Unable to lock client".to_string()))
     }
 
     async fn get_flight_info_for_command<M: ProstMessageExt>(
@@ -142,7 +147,7 @@ impl FlightSqlServiceClient {
     ) -> Result<FlightInfo, ArrowError> {
         let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
         let fi = self
-            .mut_client()?
+            .flight_client
             .get_flight_info(descriptor)
             .await
             .map_err(status_to_arrow_error)?
@@ -174,7 +179,7 @@ impl FlightSqlServiceClient {
             .map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?;
         req.metadata_mut().insert("authorization", val);
         let resp = self
-            .mut_client()?
+            .flight_client
             .handshake(req)
             .await
             .map_err(|e| ArrowError::IoError(format!("Can't handshake {}", e)))?;
@@ -208,7 +213,7 @@ impl FlightSqlServiceClient {
         let cmd = CommandStatementUpdate { query };
         let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
         let mut result = self
-            .mut_client()?
+            .flight_client
             .do_put(stream::iter(vec![FlightData {
                 flight_descriptor: Some(descriptor),
                 ..Default::default()
@@ -247,7 +252,7 @@ impl FlightSqlServiceClient {
         ticket: Ticket,
     ) -> Result<Streaming<FlightData>, ArrowError> {
         Ok(self
-            .mut_client()?
+            .flight_client
             .do_get(ticket)
             .await
             .map_err(status_to_arrow_error)?
@@ -332,7 +337,7 @@ impl FlightSqlServiceClient {
             req.metadata_mut().insert("authorization", val);
         }
         let mut result = self
-            .mut_client()?
+            .flight_client
             .do_action(req)
             .await
             .map_err(status_to_arrow_error)?
@@ -369,7 +374,7 @@ impl FlightSqlServiceClient {
 /// A PreparedStatement
 #[derive(Debug, Clone)]
 pub struct PreparedStatement<T> {
-    flight_client: Arc<Mutex<FlightServiceClient<T>>>,
+    flight_client: FlightServiceClient<T>,
     parameter_binding: Option<RecordBatch>,
     handle: Bytes,
     dataset_schema: Schema,
@@ -378,13 +383,13 @@ pub struct PreparedStatement<T> {
 
 impl PreparedStatement<Channel> {
     pub(crate) fn new(
-        client: Arc<Mutex<FlightServiceClient<Channel>>>,
+        flight_client: FlightServiceClient<Channel>,
         handle: impl Into<Bytes>,
         dataset_schema: Schema,
         parameter_schema: Schema,
     ) -> Self {
         PreparedStatement {
-            flight_client: client,
+            flight_client,
             parameter_binding: None,
             handle: handle.into(),
             dataset_schema,
@@ -399,7 +404,7 @@ impl PreparedStatement<Channel> {
         };
         let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
         let result = self
-            .mut_client()?
+            .flight_client
             .get_flight_info(descriptor)
             .await
             .map_err(status_to_arrow_error)?
@@ -414,7 +419,7 @@ impl PreparedStatement<Channel> {
         };
         let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
         let mut result = self
-            .mut_client()?
+            .flight_client
             .do_put(stream::iter(vec![FlightData {
                 flight_descriptor: Some(descriptor),
                 ..Default::default()
@@ -463,20 +468,12 @@ impl PreparedStatement<Channel> {
             body: cmd.as_any().encode_to_vec().into(),
         };
         let _ = self
-            .mut_client()?
+            .flight_client
             .do_action(action)
             .await
             .map_err(status_to_arrow_error)?;
         Ok(())
     }
-
-    fn mut_client(
-        &mut self,
-    ) -> Result<MutexGuard<FlightServiceClient<Channel>>, ArrowError> {
-        self.flight_client
-            .try_lock()
-            .map_err(|_| ArrowError::IoError("Unable to lock client".to_string()))
-    }
 }
 
 fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError {

Reply via email to