avantgardnerio commented on code in PR #269:
URL: https://github.com/apache/arrow-ballista/pull/269#discussion_r979037350


##########
ballista/rust/scheduler/src/flight_sql.rs:
##########
@@ -309,33 +362,140 @@ impl FlightSqlServiceImpl {
 impl FlightSqlService for FlightSqlServiceImpl {
     type FlightService = FlightSqlServiceImpl;
 
+    async fn do_handshake(
+        &self,
+        request: Request<Streaming<HandshakeRequest>>,
+    ) -> Result<
+        Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> 
+ Send>>>,
+        Status,
+    > {
+        debug!("do_handshake");
+        for md in request.metadata().iter() {
+            debug!("{:?}", md);
+        }
+
+        let basic = "Basic ";
+        let authorization = request
+            .metadata()
+            .get("authorization")
+            .ok_or(Status::invalid_argument("authorization field not 
present"))?
+            .to_str()
+            .map_err(|_| Status::invalid_argument("authorization not 
parsable"))?;
+        if !authorization.starts_with(basic) {
+            Err(Status::invalid_argument(format!(
+                "Auth type not implemented: {}",
+                authorization
+            )))?;
+        }
+        let base64 = &authorization[basic.len()..];
+        let bytes = base64::decode(base64)
+            .map_err(|_| Status::invalid_argument("authorization not 
parsable"))?;
+        let str = String::from_utf8(bytes)
+            .map_err(|_| Status::invalid_argument("authorization not 
parsable"))?;
+        let parts: Vec<_> = str.split(":").collect();
+        if parts.len() != 2 {
+            Err(Status::invalid_argument(format!(
+                "Invalid authorization header"
+            )))?;
+        }
+        let user = parts[0];
+        let pass = parts[1];
+        if user != "admin" || pass != "password" {
+            Err(Status::unauthenticated("Invalid credentials!"))?
+        }
+
+        let token = self.create_ctx().await?;
+
+        let result = HandshakeResponse {
+            protocol_version: 0,
+            payload: token.as_bytes().to_vec(),

Review Comment:
   We don't actually care about the credentials at all; we just need
authentication to be enabled so that the client sends this handshake request,
and we can give it back a session token.



##########
ballista/rust/scheduler/src/flight_sql.rs:
##########
@@ -309,33 +362,140 @@ impl FlightSqlServiceImpl {
 impl FlightSqlService for FlightSqlServiceImpl {
     type FlightService = FlightSqlServiceImpl;
 
+    async fn do_handshake(
+        &self,
+        request: Request<Streaming<HandshakeRequest>>,
+    ) -> Result<
+        Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> 
+ Send>>>,
+        Status,
+    > {
+        debug!("do_handshake");
+        for md in request.metadata().iter() {
+            debug!("{:?}", md);
+        }
+
+        let basic = "Basic ";
+        let authorization = request
+            .metadata()
+            .get("authorization")
+            .ok_or(Status::invalid_argument("authorization field not 
present"))?
+            .to_str()
+            .map_err(|_| Status::invalid_argument("authorization not 
parsable"))?;
+        if !authorization.starts_with(basic) {
+            Err(Status::invalid_argument(format!(
+                "Auth type not implemented: {}",
+                authorization
+            )))?;
+        }
+        let base64 = &authorization[basic.len()..];
+        let bytes = base64::decode(base64)
+            .map_err(|_| Status::invalid_argument("authorization not 
parsable"))?;
+        let str = String::from_utf8(bytes)
+            .map_err(|_| Status::invalid_argument("authorization not 
parsable"))?;
+        let parts: Vec<_> = str.split(":").collect();
+        if parts.len() != 2 {
+            Err(Status::invalid_argument(format!(
+                "Invalid authorization header"
+            )))?;
+        }
+        let user = parts[0];
+        let pass = parts[1];
+        if user != "admin" || pass != "password" {
+            Err(Status::unauthenticated("Invalid credentials!"))?
+        }
+
+        let token = self.create_ctx().await?;
+
+        let result = HandshakeResponse {
+            protocol_version: 0,
+            payload: token.as_bytes().to_vec(),
+        };
+        let result = Ok(result);
+        let output = futures::stream::iter(vec![result]);
+        let str = format!("Bearer {}", token.to_string());
+        let mut resp: Response<Pin<Box<dyn Stream<Item = Result<_, _>> + 
Send>>> =
+            Response::new(Box::pin(output));
+        let md = MetadataValue::try_from(str)
+            .map_err(|_| Status::invalid_argument("authorization not 
parsable"))?;
+        resp.metadata_mut().insert("authorization", md);
+        Ok(resp)
+    }
+
+    async fn do_get_fallback(
+        &self,
+        _request: Request<Ticket>,
+        message: prost_types::Any,
+    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
+        println!("type_url: {}", message.type_url);
+        if message.is::<protobuf::Action>() {
+            println!("got action!");
+            let action: protobuf::Action = message
+                .unpack()
+                .map_err(|e| Status::internal(format!("{:?}", e)))?
+                .ok_or(Status::internal("Expected an Action but got None!"))?;
+            println!("action={:?}", action);
+            let (host, port) = match &action.action_type {
+                Some(FetchPartition(fp)) => (fp.host.clone(), fp.port),
+                None => Err(Status::internal("Expected an ActionType but got 
None!"))?,
+            };
+
+            let addr = format!("http://{}:{}";, host, port);
+            println!("BallistaClient connecting to {}", addr);
+            let connection =
+                create_grpc_client_connection(addr.clone())
+                    .await
+                    .map_err(|e| {
+                        Status::internal(format!(
+                        "Error connecting to Ballista scheduler or executor at 
{}: {:?}",
+                        addr, e
+                    ))
+                    })?;
+            let mut flight_client = FlightServiceClient::new(connection);

Review Comment:
   Proxy the flight to the correct executor instance.



##########
ballista/rust/scheduler/src/flight_sql.rs:
##########
@@ -309,33 +362,140 @@ impl FlightSqlServiceImpl {
 impl FlightSqlService for FlightSqlServiceImpl {
     type FlightService = FlightSqlServiceImpl;
 
+    async fn do_handshake(
+        &self,
+        request: Request<Streaming<HandshakeRequest>>,
+    ) -> Result<
+        Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> 
+ Send>>>,
+        Status,
+    > {
+        debug!("do_handshake");
+        for md in request.metadata().iter() {
+            debug!("{:?}", md);
+        }
+
+        let basic = "Basic ";
+        let authorization = request
+            .metadata()
+            .get("authorization")
+            .ok_or(Status::invalid_argument("authorization field not 
present"))?
+            .to_str()
+            .map_err(|_| Status::invalid_argument("authorization not 
parsable"))?;
+        if !authorization.starts_with(basic) {
+            Err(Status::invalid_argument(format!(
+                "Auth type not implemented: {}",
+                authorization
+            )))?;
+        }
+        let base64 = &authorization[basic.len()..];
+        let bytes = base64::decode(base64)
+            .map_err(|_| Status::invalid_argument("authorization not 
parsable"))?;
+        let str = String::from_utf8(bytes)
+            .map_err(|_| Status::invalid_argument("authorization not 
parsable"))?;
+        let parts: Vec<_> = str.split(":").collect();
+        if parts.len() != 2 {
+            Err(Status::invalid_argument(format!(
+                "Invalid authorization header"
+            )))?;
+        }
+        let user = parts[0];
+        let pass = parts[1];
+        if user != "admin" || pass != "password" {
+            Err(Status::unauthenticated("Invalid credentials!"))?
+        }
+
+        let token = self.create_ctx().await?;
+
+        let result = HandshakeResponse {
+            protocol_version: 0,
+            payload: token.as_bytes().to_vec(),
+        };
+        let result = Ok(result);
+        let output = futures::stream::iter(vec![result]);
+        let str = format!("Bearer {}", token.to_string());
+        let mut resp: Response<Pin<Box<dyn Stream<Item = Result<_, _>> + 
Send>>> =
+            Response::new(Box::pin(output));
+        let md = MetadataValue::try_from(str)
+            .map_err(|_| Status::invalid_argument("authorization not 
parsable"))?;
+        resp.metadata_mut().insert("authorization", md);
+        Ok(resp)
+    }
+
+    async fn do_get_fallback(
+        &self,
+        _request: Request<Ticket>,
+        message: prost_types::Any,
+    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
+        println!("type_url: {}", message.type_url);
+        if message.is::<protobuf::Action>() {
+            println!("got action!");
+            let action: protobuf::Action = message
+                .unpack()
+                .map_err(|e| Status::internal(format!("{:?}", e)))?
+                .ok_or(Status::internal("Expected an Action but got None!"))?;
+            println!("action={:?}", action);
+            let (host, port) = match &action.action_type {
+                Some(FetchPartition(fp)) => (fp.host.clone(), fp.port),
+                None => Err(Status::internal("Expected an ActionType but got 
None!"))?,
+            };
+
+            let addr = format!("http://{}:{}";, host, port);
+            println!("BallistaClient connecting to {}", addr);
+            let connection =
+                create_grpc_client_connection(addr.clone())
+                    .await
+                    .map_err(|e| {
+                        Status::internal(format!(
+                        "Error connecting to Ballista scheduler or executor at 
{}: {:?}",
+                        addr, e
+                    ))
+                    })?;
+            let mut flight_client = FlightServiceClient::new(connection);
+            let buf = action.encode_to_vec();
+            let request = Request::new(Ticket { ticket: buf });
+
+            let stream = flight_client
+                .do_get(request)
+                .await
+                .map_err(|e| Status::internal(format!("{:?}", e)))?
+                .into_inner();
+            return Ok(Response::new(Box::pin(stream)));
+        }
+
+        Err(Status::unimplemented(format!(
+            "do_get: The defined request is invalid: {}",
+            message.type_url
+        )))
+    }
+
     async fn get_flight_info_statement(
         &self,
         query: CommandStatementQuery,
-        _request: Request<FlightDescriptor>,
+        request: Request<FlightDescriptor>,
     ) -> Result<Response<FlightInfo>, Status> {
-        debug!("Got query:\n{}", query.query);
+        debug!("get_flight_info_statement query:\n{}", query.query);
 
-        let ctx = self.create_ctx().await?;
+        let ctx = self.get_ctx(&request)?;
         let plan = Self::prepare_statement(&query.query, &ctx).await?;
         let resp = self.execute_plan(ctx, &plan).await?;
 
-        debug!("Responding to query...");
+        debug!("Returning flight info...");
         Ok(resp)
     }
 
     async fn get_flight_info_prepared_statement(
         &self,
         handle: CommandPreparedStatementQuery,
-        _request: Request<FlightDescriptor>,
+        request: Request<FlightDescriptor>,
     ) -> Result<Response<FlightInfo>, Status> {
-        let ctx = self.create_ctx().await?;
+        debug!("get_flight_info_prepared_statement");
+        let ctx = self.get_ctx(&request)?;

Review Comment:
   Use the UUID token to get the cached `SessionContext`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to