This is an automated email from the ASF dual-hosted git repository. tustvold pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push: new b717b3939 Better flight SQL example codes (#4144) b717b3939 is described below commit b717b39393367d1de7577078c13b91c59a62d581 Author: sundyli <543950...@qq.com> AuthorDate: Fri Apr 28 02:47:17 2023 -0700 Better flight SQL example codes (#4144) * Better flight sql example codes * Better flight sql example codes * feat: flight sql server enable tcp nodelay * Remove unnecessary doc --------- Co-authored-by: Raphael Taylor-Davies <r.taylordav...@googlemail.com> --- arrow-flight/examples/flight_sql_server.rs | 196 ++++++++++++++++------------- 1 file changed, 107 insertions(+), 89 deletions(-) diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index 43154420d..23d71090a 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -546,8 +546,7 @@ impl ProstMessageExt for FetchResults { #[cfg(test)] mod tests { use super::*; - use futures::future::BoxFuture; - use futures::{FutureExt, TryStreamExt}; + use futures::TryStreamExt; use std::fs; use std::future::Future; use std::net::SocketAddr; @@ -571,42 +570,6 @@ mod tests { (incoming, addr) } - async fn client_with_uds(path: String) -> FlightSqlServiceClient<Channel> { - let connector = service_fn(move |_| UnixStream::connect(path.clone())); - let channel = Endpoint::try_from("http://example.com") - .unwrap() - .connect_with_connector(connector) - .await - .unwrap(); - FlightSqlServiceClient::new(channel) - } - - type ServeFut = BoxFuture<'static, Result<(), tonic::transport::Error>>; - - async fn create_https_server( - ) -> Result<(ServeFut, SocketAddr), tonic::transport::Error> { - let cert = std::fs::read_to_string("examples/data/server.pem").unwrap(); - let key = std::fs::read_to_string("examples/data/server.key").unwrap(); - let client_ca = std::fs::read_to_string("examples/data/client_ca.pem").unwrap(); - - let tls_config = 
ServerTlsConfig::new() - .identity(Identity::from_pem(&cert, &key)) - .client_ca_root(Certificate::from_pem(&client_ca)); - - let (incoming, addr) = bind_tcp().await; - - let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); - - let serve = Server::builder() - .tls_config(tls_config) - .unwrap() - .add_service(svc) - .serve_with_incoming(incoming) - .boxed(); - - Ok((serve, addr)) - } - fn endpoint(uri: String) -> Result<Endpoint, ArrowError> { let endpoint = Endpoint::new(uri) .map_err(|_| ArrowError::IoError("Cannot create endpoint".to_string()))? @@ -621,56 +584,12 @@ mod tests { Ok(endpoint) } - #[tokio::test] - async fn test_select_https() { - let (serve, addr) = create_https_server().await.unwrap(); - let uri = format!("https://{}:{}", addr.ip(), addr.port()); - - let request_future = async { - let cert = std::fs::read_to_string("examples/data/client1.pem").unwrap(); - let key = std::fs::read_to_string("examples/data/client1.key").unwrap(); - let server_ca = std::fs::read_to_string("examples/data/ca.pem").unwrap(); - - let tls_config = ClientTlsConfig::new() - .domain_name("localhost") - .ca_certificate(Certificate::from_pem(&server_ca)) - .identity(Identity::from_pem(cert, key)); - let endpoint = endpoint(uri).unwrap().tls_config(tls_config).unwrap(); - let channel = endpoint.connect().await.unwrap(); - let mut client = FlightSqlServiceClient::new(channel); - let token = client.handshake("admin", "password").await.unwrap(); - client.set_token(String::from_utf8(token.to_vec()).unwrap()); - println!("Auth succeeded with token: {:?}", token); - let mut stmt = client.prepare("select 1;".to_string()).await.unwrap(); - let flight_info = stmt.execute().await.unwrap(); - let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone(); - let flight_data = client.do_get(ticket).await.unwrap(); - let flight_data: Vec<FlightData> = flight_data.try_collect().await.unwrap(); - let batches = flight_data_to_batches(&flight_data).unwrap(); - let res = 
pretty_format_batches(batches.as_slice()).unwrap(); - let expected = r#" -+-------------------+ -| salutation | -+-------------------+ -| Hello, FlightSQL! | -+-------------------+"# - .trim() - .to_string(); - assert_eq!(res.to_string(), expected); - }; - - tokio::select! { - _ = serve => panic!("server finished"), - _ = request_future => println!("Client finished!"), - } - } - async fn auth_client(client: &mut FlightSqlServiceClient<Channel>) { let token = client.handshake("admin", "password").await.unwrap(); client.set_token(String::from_utf8(token.to_vec()).unwrap()); } - async fn test_client<F, C>(f: F) + async fn test_uds_client<F, C>(f: F) where F: FnOnce(FlightSqlServiceClient<Channel>) -> C, C: Future<Output = ()>, @@ -682,14 +601,91 @@ mod tests { let uds = UnixListener::bind(path.clone()).unwrap(); let stream = UnixListenerStream::new(uds); - // We would just listen on TCP, but it seems impossible to know when tonic is ready to serve let service = FlightSqlServiceImpl {}; let serve_future = Server::builder() .add_service(FlightServiceServer::new(service)) .serve_with_incoming(stream); let request_future = async { - let client = client_with_uds(path).await; + let connector = service_fn(move |_| UnixStream::connect(path.clone())); + let channel = Endpoint::try_from("http://example.com") + .unwrap() + .connect_with_connector(connector) + .await + .unwrap(); + let client = FlightSqlServiceClient::new(channel); + f(client).await + }; + + tokio::select! 
{ + _ = serve_future => panic!("server returned first"), + _ = request_future => println!("Client finished!"), + } + } + + async fn test_http_client<F, C>(f: F) + where + F: FnOnce(FlightSqlServiceClient<Channel>) -> C, + C: Future<Output = ()>, + { + let (incoming, addr) = bind_tcp().await; + let uri = format!("http://{}:{}", addr.ip(), addr.port()); + + let service = FlightSqlServiceImpl {}; + let serve_future = Server::builder() + .add_service(FlightServiceServer::new(service)) + .serve_with_incoming(incoming); + + let request_future = async { + let endpoint = endpoint(uri).unwrap(); + let channel = endpoint.connect().await.unwrap(); + let client = FlightSqlServiceClient::new(channel); + f(client).await + }; + + tokio::select! { + _ = serve_future => panic!("server returned first"), + _ = request_future => println!("Client finished!"), + } + } + + async fn test_https_client<F, C>(f: F) + where + F: FnOnce(FlightSqlServiceClient<Channel>) -> C, + C: Future<Output = ()>, + { + let cert = std::fs::read_to_string("examples/data/server.pem").unwrap(); + let key = std::fs::read_to_string("examples/data/server.key").unwrap(); + let client_ca = std::fs::read_to_string("examples/data/client_ca.pem").unwrap(); + + let tls_config = ServerTlsConfig::new() + .identity(Identity::from_pem(&cert, &key)) + .client_ca_root(Certificate::from_pem(&client_ca)); + + let (incoming, addr) = bind_tcp().await; + let uri = format!("https://{}:{}", addr.ip(), addr.port()); + + let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); + + let serve_future = Server::builder() + .tls_config(tls_config) + .unwrap() + .add_service(svc) + .serve_with_incoming(incoming); + + let request_future = async { + let cert = std::fs::read_to_string("examples/data/client1.pem").unwrap(); + let key = std::fs::read_to_string("examples/data/client1.key").unwrap(); + let server_ca = std::fs::read_to_string("examples/data/ca.pem").unwrap(); + + let tls_config = ClientTlsConfig::new() + 
.domain_name("localhost") + .ca_certificate(Certificate::from_pem(&server_ca)) + .identity(Identity::from_pem(cert, key)); + + let endpoint = endpoint(uri).unwrap().tls_config(tls_config).unwrap(); + let channel = endpoint.connect().await.unwrap(); + let client = FlightSqlServiceClient::new(channel); f(client).await }; @@ -699,16 +695,38 @@ mod tests { } } + async fn test_all_clients<F, C>(task: F) + where + F: FnOnce(FlightSqlServiceClient<Channel>) -> C + Copy, + C: Future<Output = ()>, + { + println!("testing uds client"); + test_uds_client(task).await; + println!("======="); + + println!("testing http client"); + test_http_client(task).await; + println!("======="); + + println!("testing https client"); + test_https_client(task).await; + println!("======="); + } + #[tokio::test] - async fn test_select_1() { - test_client(|mut client| async move { + async fn test_select() { + test_all_clients(|mut client| async move { auth_client(&mut client).await; + let mut stmt = client.prepare("select 1;".to_string()).await.unwrap(); + let flight_info = stmt.execute().await.unwrap(); + let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone(); let flight_data = client.do_get(ticket).await.unwrap(); let flight_data: Vec<FlightData> = flight_data.try_collect().await.unwrap(); let batches = flight_data_to_batches(&flight_data).unwrap(); + let res = pretty_format_batches(batches.as_slice()).unwrap(); let expected = r#" +-------------------+ @@ -725,7 +743,7 @@ mod tests { #[tokio::test] async fn test_execute_update() { - test_client(|mut client| async move { + test_all_clients(|mut client| async move { auth_client(&mut client).await; let res = client .execute_update("creat table test(a int);".to_string()) @@ -738,7 +756,7 @@ mod tests { #[tokio::test] async fn test_auth() { - test_client(|mut client| async move { + test_all_clients(|mut client| async move { // no handshake assert!(client .prepare("select 1;".to_string())