avantgardnerio commented on code in PR #269: URL: https://github.com/apache/arrow-ballista/pull/269#discussion_r979037350
########## ballista/rust/scheduler/src/flight_sql.rs: ########## @@ -309,33 +362,140 @@ impl FlightSqlServiceImpl { impl FlightSqlService for FlightSqlServiceImpl { type FlightService = FlightSqlServiceImpl; + async fn do_handshake( + &self, + request: Request<Streaming<HandshakeRequest>>, + ) -> Result< + Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>, + Status, + > { + debug!("do_handshake"); + for md in request.metadata().iter() { + debug!("{:?}", md); + } + + let basic = "Basic "; + let authorization = request + .metadata() + .get("authorization") + .ok_or(Status::invalid_argument("authorization field not present"))? + .to_str() + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + if !authorization.starts_with(basic) { + Err(Status::invalid_argument(format!( + "Auth type not implemented: {}", + authorization + )))?; + } + let base64 = &authorization[basic.len()..]; + let bytes = base64::decode(base64) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let str = String::from_utf8(bytes) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let parts: Vec<_> = str.split(":").collect(); + if parts.len() != 2 { + Err(Status::invalid_argument(format!( + "Invalid authorization header" + )))?; + } + let user = parts[0]; + let pass = parts[1]; + if user != "admin" || pass != "password" { + Err(Status::unauthenticated("Invalid credentials!"))? + } + + let token = self.create_ctx().await?; + + let result = HandshakeResponse { + protocol_version: 0, + payload: token.as_bytes().to_vec(), Review Comment: We don't actually care about the credentials at all; we only need authentication to be enabled so that the client sends this handshake request, which lets us hand back a session token. 
########## ballista/rust/scheduler/src/flight_sql.rs: ########## @@ -309,33 +362,140 @@ impl FlightSqlServiceImpl { impl FlightSqlService for FlightSqlServiceImpl { type FlightService = FlightSqlServiceImpl; + async fn do_handshake( + &self, + request: Request<Streaming<HandshakeRequest>>, + ) -> Result< + Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>, + Status, + > { + debug!("do_handshake"); + for md in request.metadata().iter() { + debug!("{:?}", md); + } + + let basic = "Basic "; + let authorization = request + .metadata() + .get("authorization") + .ok_or(Status::invalid_argument("authorization field not present"))? + .to_str() + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + if !authorization.starts_with(basic) { + Err(Status::invalid_argument(format!( + "Auth type not implemented: {}", + authorization + )))?; + } + let base64 = &authorization[basic.len()..]; + let bytes = base64::decode(base64) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let str = String::from_utf8(bytes) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let parts: Vec<_> = str.split(":").collect(); + if parts.len() != 2 { + Err(Status::invalid_argument(format!( + "Invalid authorization header" + )))?; + } + let user = parts[0]; + let pass = parts[1]; + if user != "admin" || pass != "password" { + Err(Status::unauthenticated("Invalid credentials!"))? 
+ } + + let token = self.create_ctx().await?; + + let result = HandshakeResponse { + protocol_version: 0, + payload: token.as_bytes().to_vec(), + }; + let result = Ok(result); + let output = futures::stream::iter(vec![result]); + let str = format!("Bearer {}", token.to_string()); + let mut resp: Response<Pin<Box<dyn Stream<Item = Result<_, _>> + Send>>> = + Response::new(Box::pin(output)); + let md = MetadataValue::try_from(str) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + resp.metadata_mut().insert("authorization", md); + Ok(resp) + } + + async fn do_get_fallback( + &self, + _request: Request<Ticket>, + message: prost_types::Any, + ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { + println!("type_url: {}", message.type_url); + if message.is::<protobuf::Action>() { + println!("got action!"); + let action: protobuf::Action = message + .unpack() + .map_err(|e| Status::internal(format!("{:?}", e)))? + .ok_or(Status::internal("Expected an Action but got None!"))?; + println!("action={:?}", action); + let (host, port) = match &action.action_type { + Some(FetchPartition(fp)) => (fp.host.clone(), fp.port), + None => Err(Status::internal("Expected an ActionType but got None!"))?, + }; + + let addr = format!("http://{}:{}", host, port); + println!("BallistaClient connecting to {}", addr); + let connection = + create_grpc_client_connection(addr.clone()) + .await + .map_err(|e| { + Status::internal(format!( + "Error connecting to Ballista scheduler or executor at {}: {:?}", + addr, e + )) + })?; + let mut flight_client = FlightServiceClient::new(connection); Review Comment: Proxy the flight to the correct executor instance. 
########## ballista/rust/scheduler/src/flight_sql.rs: ########## @@ -309,33 +362,140 @@ impl FlightSqlServiceImpl { impl FlightSqlService for FlightSqlServiceImpl { type FlightService = FlightSqlServiceImpl; + async fn do_handshake( + &self, + request: Request<Streaming<HandshakeRequest>>, + ) -> Result< + Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>, + Status, + > { + debug!("do_handshake"); + for md in request.metadata().iter() { + debug!("{:?}", md); + } + + let basic = "Basic "; + let authorization = request + .metadata() + .get("authorization") + .ok_or(Status::invalid_argument("authorization field not present"))? + .to_str() + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + if !authorization.starts_with(basic) { + Err(Status::invalid_argument(format!( + "Auth type not implemented: {}", + authorization + )))?; + } + let base64 = &authorization[basic.len()..]; + let bytes = base64::decode(base64) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let str = String::from_utf8(bytes) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let parts: Vec<_> = str.split(":").collect(); + if parts.len() != 2 { + Err(Status::invalid_argument(format!( + "Invalid authorization header" + )))?; + } + let user = parts[0]; + let pass = parts[1]; + if user != "admin" || pass != "password" { + Err(Status::unauthenticated("Invalid credentials!"))? 
+ } + + let token = self.create_ctx().await?; + + let result = HandshakeResponse { + protocol_version: 0, + payload: token.as_bytes().to_vec(), + }; + let result = Ok(result); + let output = futures::stream::iter(vec![result]); + let str = format!("Bearer {}", token.to_string()); + let mut resp: Response<Pin<Box<dyn Stream<Item = Result<_, _>> + Send>>> = + Response::new(Box::pin(output)); + let md = MetadataValue::try_from(str) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + resp.metadata_mut().insert("authorization", md); + Ok(resp) + } + + async fn do_get_fallback( + &self, + _request: Request<Ticket>, + message: prost_types::Any, + ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> { + println!("type_url: {}", message.type_url); + if message.is::<protobuf::Action>() { + println!("got action!"); + let action: protobuf::Action = message + .unpack() + .map_err(|e| Status::internal(format!("{:?}", e)))? + .ok_or(Status::internal("Expected an Action but got None!"))?; + println!("action={:?}", action); + let (host, port) = match &action.action_type { + Some(FetchPartition(fp)) => (fp.host.clone(), fp.port), + None => Err(Status::internal("Expected an ActionType but got None!"))?, + }; + + let addr = format!("http://{}:{}", host, port); + println!("BallistaClient connecting to {}", addr); + let connection = + create_grpc_client_connection(addr.clone()) + .await + .map_err(|e| { + Status::internal(format!( + "Error connecting to Ballista scheduler or executor at {}: {:?}", + addr, e + )) + })?; + let mut flight_client = FlightServiceClient::new(connection); + let buf = action.encode_to_vec(); + let request = Request::new(Ticket { ticket: buf }); + + let stream = flight_client + .do_get(request) + .await + .map_err(|e| Status::internal(format!("{:?}", e)))? 
+ .into_inner(); + return Ok(Response::new(Box::pin(stream))); + } + + Err(Status::unimplemented(format!( + "do_get: The defined request is invalid: {}", + message.type_url + ))) + } + async fn get_flight_info_statement( &self, query: CommandStatementQuery, - _request: Request<FlightDescriptor>, + request: Request<FlightDescriptor>, ) -> Result<Response<FlightInfo>, Status> { - debug!("Got query:\n{}", query.query); + debug!("get_flight_info_statement query:\n{}", query.query); - let ctx = self.create_ctx().await?; + let ctx = self.get_ctx(&request)?; let plan = Self::prepare_statement(&query.query, &ctx).await?; let resp = self.execute_plan(ctx, &plan).await?; - debug!("Responding to query..."); + debug!("Returning flight info..."); Ok(resp) } async fn get_flight_info_prepared_statement( &self, handle: CommandPreparedStatementQuery, - _request: Request<FlightDescriptor>, + request: Request<FlightDescriptor>, ) -> Result<Response<FlightInfo>, Status> { - let ctx = self.create_ctx().await?; + debug!("get_flight_info_prepared_statement"); + let ctx = self.get_ctx(&request)?; Review Comment: Use the UUID token to get the cached `SessionContext`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org