martin-g commented on code in PR #1351:
URL:
https://github.com/apache/datafusion-ballista/pull/1351#discussion_r2634001951
##########
ballista/scheduler/src/flight_proxy_service.rs:
##########
@@ -0,0 +1,157 @@
+use arrow_flight::flight_service_client::FlightServiceClient;
+use arrow_flight::flight_service_server::FlightService;
+use arrow_flight::{
+ Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor,
FlightInfo,
+ HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult,
Ticket,
+};
+use ballista_core::error::BallistaError;
+use ballista_core::serde::decode_protobuf;
+use ballista_core::serde::scheduler::Action as BallistaAction;
+use ballista_core::utils::{create_grpc_client_connection, GrpcClientConfig};
+use futures::{Stream, TryFutureExt};
+use log::debug;
+use std::pin::Pin;
+use tonic::{Request, Response, Status, Streaming};
+
+/// Service implementing a proxy from scheduler to executor Apache Arrow
Flight Protocol
+#[derive(Clone)]
+pub struct BallistaFlightProxyService {}
+
+impl BallistaFlightProxyService {
+ pub fn new() -> Self {
+ Self {}
+ }
+}
+
+impl Default for BallistaFlightProxyService {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+type BoxedFlightStream<T> =
+ Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + 'static>>;
+
+#[tonic::async_trait]
+impl FlightService for BallistaFlightProxyService {
+ type DoActionStream = BoxedFlightStream<arrow_flight::Result>;
+ type DoExchangeStream = BoxedFlightStream<FlightData>;
+ type DoGetStream = BoxedFlightStream<FlightData>;
+ type DoPutStream = BoxedFlightStream<PutResult>;
+ type HandshakeStream = BoxedFlightStream<HandshakeResponse>;
+ type ListActionsStream = BoxedFlightStream<ActionType>;
+ type ListFlightsStream = BoxedFlightStream<FlightInfo>;
+ async fn handshake(
+ &self,
+ _request: Request<Streaming<HandshakeRequest>>,
+ ) -> Result<Response<Self::HandshakeStream>, Status> {
+ Err(Status::unimplemented("handshake"))
+ }
+
+ async fn list_flights(
+ &self,
+ _request: Request<Criteria>,
+ ) -> Result<Response<Self::ListFlightsStream>, Status> {
+ Err(Status::unimplemented("list_flights"))
+ }
+
+ async fn get_flight_info(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<FlightInfo>, Status> {
+ Err(Status::unimplemented("get_flight_info"))
+ }
+
+ async fn poll_flight_info(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<PollInfo>, Status> {
+ Err(Status::unimplemented("poll_flight_info"))
+ }
+
+ async fn get_schema(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<SchemaResult>, Status> {
+ Err(Status::unimplemented("get_schema"))
+ }
+
+ async fn do_get(
+ &self,
+ request: Request<Ticket>,
+ ) -> Result<Response<Self::DoGetStream>, Status> {
+ let ticket = request.into_inner();
+
+ let action =
+ decode_protobuf(&ticket.ticket).map_err(|e|
from_ballista_err(&e))?;
+
+ match &action {
+ BallistaAction::FetchPartition {
+ host, port, job_id, ..
+ } => {
+ debug!("Fetching results for job id: {job_id} from
{host}:{port}");
+ let mut client = get_flight_client(host, port)
+ .map_err(|e| from_ballista_err(&e))
+ .await?;
+ client
+ .do_get(Request::new(ticket))
+ .await
+ .map(|r| Response::new(Box::pin(r.into_inner()) as
Self::DoGetStream))
+ }
+ }
+ }
+
+ async fn do_put(
+ &self,
+ _request: Request<Streaming<FlightData>>,
+ ) -> Result<Response<Self::DoPutStream>, Status> {
+ Err(Status::unimplemented("do_put"))
+ }
+
+ async fn do_exchange(
+ &self,
+ _request: Request<Streaming<FlightData>>,
+ ) -> Result<Response<Self::DoExchangeStream>, Status> {
+ Err(Status::unimplemented("do_exchange"))
+ }
+
+ async fn do_action(
+ &self,
+ _request: Request<Action>,
+ ) -> Result<Response<Self::DoActionStream>, Status> {
+ Err(Status::unimplemented("do_action"))
+ }
+
+ async fn list_actions(
+ &self,
+ _request: Request<Empty>,
+ ) -> Result<Response<Self::ListActionsStream>, Status> {
+ Err(Status::unimplemented("list_actions"))
+ }
+}
+
+fn from_ballista_err(e: &ballista_core::error::BallistaError) -> Status {
+ Status::internal(format!("Ballista Error: {e:?}"))
+}
+
+async fn get_flight_client(
+ host: &String,
Review Comment:
```suggestion
host: &str,
```
##########
ballista/core/src/execution_plans/distributed_query.rs:
##########
@@ -360,9 +363,22 @@ async fn execute_query(
let duration = Duration::from_millis(duration);
info!("Job {job_id} finished executing in {duration:?} ");
+ let FlightEndpointInfo {
+ address: flight_proxy_address,
+ } = scheduler
+ .get_flight_endpoint_info(FlightEndpointInfoParams {})
+ .await
+ .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
Review Comment:
Thinking out loud: does it need to fail here, or is it OK to log the error and
fall back to no proxy?
##########
ballista/scheduler/src/scheduler_process.rs:
##########
@@ -131,7 +176,19 @@ pub async fn start_server(
) -> ballista_core::error::Result<()> {
info!("Ballista v{BALLISTA_VERSION} Scheduler listening on {address:?}");
let scheduler =
- create_scheduler::<LogicalPlanNode, PhysicalPlanNode>(cluster,
config).await?;
+ create_scheduler::<LogicalPlanNode, PhysicalPlanNode>(cluster,
config.clone())
+ .await?;
- start_grpc_service(address, scheduler).await
+ info!(
+ "advertise_flight_sql_endpoint: {:?}",
+ config.advertise_flight_sql_endpoint
+ );
+ match config.advertise_flight_sql_endpoint {
+ Some(_) => {
+ info!("Starting flight proxy");
+ let _flight_proxy = start_flight_proxy_server(config);
Review Comment:
This ignores any errors returned by the proxy. They are logged, so this
might be intentional, to keep working even without the proxy.
If both services should fail when either of them fails, you could use
`tokio::select!`, e.g.:
```diff
- let _flight_proxy = start_flight_proxy_server(config);
- start_grpc_service(address, scheduler).await
+ let flight_proxy = start_flight_proxy_server(config);
+ tokio::select! {
+ result = start_grpc_service(address, scheduler) => result,
+ result = flight_proxy => {
+ result.map_err(|e|
BallistaError::Internal(format!("Flight proxy task panicked: {e:?}")))?
+ }
+ }
```
Also, there is a race here: the main service may start first, and no
requests can be routed through the proxy until the proxy has fully started too.
##########
ballista/scheduler/src/flight_proxy_service.rs:
##########
@@ -0,0 +1,157 @@
+use arrow_flight::flight_service_client::FlightServiceClient;
+use arrow_flight::flight_service_server::FlightService;
+use arrow_flight::{
+ Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor,
FlightInfo,
+ HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult,
Ticket,
+};
+use ballista_core::error::BallistaError;
+use ballista_core::serde::decode_protobuf;
+use ballista_core::serde::scheduler::Action as BallistaAction;
+use ballista_core::utils::{create_grpc_client_connection, GrpcClientConfig};
+use futures::{Stream, TryFutureExt};
+use log::debug;
+use std::pin::Pin;
+use tonic::{Request, Response, Status, Streaming};
+
+/// Service implementing a proxy from scheduler to executor Apache Arrow
Flight Protocol
+#[derive(Clone)]
+pub struct BallistaFlightProxyService {}
+
+impl BallistaFlightProxyService {
+ pub fn new() -> Self {
+ Self {}
+ }
+}
+
+impl Default for BallistaFlightProxyService {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+type BoxedFlightStream<T> =
+ Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + 'static>>;
+
+#[tonic::async_trait]
+impl FlightService for BallistaFlightProxyService {
+ type DoActionStream = BoxedFlightStream<arrow_flight::Result>;
+ type DoExchangeStream = BoxedFlightStream<FlightData>;
+ type DoGetStream = BoxedFlightStream<FlightData>;
+ type DoPutStream = BoxedFlightStream<PutResult>;
+ type HandshakeStream = BoxedFlightStream<HandshakeResponse>;
+ type ListActionsStream = BoxedFlightStream<ActionType>;
+ type ListFlightsStream = BoxedFlightStream<FlightInfo>;
+ async fn handshake(
+ &self,
+ _request: Request<Streaming<HandshakeRequest>>,
+ ) -> Result<Response<Self::HandshakeStream>, Status> {
+ Err(Status::unimplemented("handshake"))
+ }
+
+ async fn list_flights(
+ &self,
+ _request: Request<Criteria>,
+ ) -> Result<Response<Self::ListFlightsStream>, Status> {
+ Err(Status::unimplemented("list_flights"))
+ }
+
+ async fn get_flight_info(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<FlightInfo>, Status> {
+ Err(Status::unimplemented("get_flight_info"))
+ }
+
+ async fn poll_flight_info(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<PollInfo>, Status> {
+ Err(Status::unimplemented("poll_flight_info"))
+ }
+
+ async fn get_schema(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<SchemaResult>, Status> {
+ Err(Status::unimplemented("get_schema"))
+ }
+
+ async fn do_get(
+ &self,
+ request: Request<Ticket>,
+ ) -> Result<Response<Self::DoGetStream>, Status> {
+ let ticket = request.into_inner();
+
+ let action =
+ decode_protobuf(&ticket.ticket).map_err(|e|
from_ballista_err(&e))?;
+
+ match &action {
+ BallistaAction::FetchPartition {
+ host, port, job_id, ..
+ } => {
+ debug!("Fetching results for job id: {job_id} from
{host}:{port}");
+ let mut client = get_flight_client(host, port)
+ .map_err(|e| from_ballista_err(&e))
+ .await?;
+ client
+ .do_get(Request::new(ticket))
+ .await
+ .map(|r| Response::new(Box::pin(r.into_inner()) as
Self::DoGetStream))
+ }
+ }
+ }
+
+ async fn do_put(
+ &self,
+ _request: Request<Streaming<FlightData>>,
+ ) -> Result<Response<Self::DoPutStream>, Status> {
+ Err(Status::unimplemented("do_put"))
+ }
+
+ async fn do_exchange(
+ &self,
+ _request: Request<Streaming<FlightData>>,
+ ) -> Result<Response<Self::DoExchangeStream>, Status> {
+ Err(Status::unimplemented("do_exchange"))
+ }
+
+ async fn do_action(
+ &self,
+ _request: Request<Action>,
+ ) -> Result<Response<Self::DoActionStream>, Status> {
+ Err(Status::unimplemented("do_action"))
+ }
+
+ async fn list_actions(
+ &self,
+ _request: Request<Empty>,
+ ) -> Result<Response<Self::ListActionsStream>, Status> {
+ Err(Status::unimplemented("list_actions"))
+ }
+}
+
+fn from_ballista_err(e: &ballista_core::error::BallistaError) -> Status {
+ Status::internal(format!("Ballista Error: {e:?}"))
+}
+
+async fn get_flight_client(
+ host: &String,
+ port: &u16,
+) -> Result<FlightServiceClient<tonic::transport::channel::Channel>,
BallistaError> {
+ let addr = format!("http://{host}:{port}");
+ let grpc_config = GrpcClientConfig::default();
+ debug!("FlightProxyService connecting to {addr}");
+ let connection = create_grpc_client_connection(addr.clone(), &grpc_config)
+ .await
+ .map_err(|e| {
+ BallistaError::GrpcConnectionError(format!(
+ "Error connecting to Ballista scheduler or executor at {addr}:
{e:?}"
+ ))
+ })?;
+ let flight_client = FlightServiceClient::new(connection)
+ .max_decoding_message_size(16 * 1024 * 1024)
Review Comment:
Can/should we use `config.grpc_server_max_encoding_message_size` or a new
setting?
Same for `min` below.
##########
ballista/scheduler/src/flight_proxy_service.rs:
##########
@@ -0,0 +1,157 @@
+use arrow_flight::flight_service_client::FlightServiceClient;
+use arrow_flight::flight_service_server::FlightService;
+use arrow_flight::{
+ Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor,
FlightInfo,
+ HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult,
Ticket,
+};
+use ballista_core::error::BallistaError;
+use ballista_core::serde::decode_protobuf;
+use ballista_core::serde::scheduler::Action as BallistaAction;
+use ballista_core::utils::{create_grpc_client_connection, GrpcClientConfig};
+use futures::{Stream, TryFutureExt};
+use log::debug;
+use std::pin::Pin;
+use tonic::{Request, Response, Status, Streaming};
+
+/// Service implementing a proxy from scheduler to executor Apache Arrow
Flight Protocol
+#[derive(Clone)]
+pub struct BallistaFlightProxyService {}
+
+impl BallistaFlightProxyService {
+ pub fn new() -> Self {
+ Self {}
+ }
+}
+
+impl Default for BallistaFlightProxyService {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+type BoxedFlightStream<T> =
+ Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + 'static>>;
+
+#[tonic::async_trait]
+impl FlightService for BallistaFlightProxyService {
+ type DoActionStream = BoxedFlightStream<arrow_flight::Result>;
+ type DoExchangeStream = BoxedFlightStream<FlightData>;
+ type DoGetStream = BoxedFlightStream<FlightData>;
+ type DoPutStream = BoxedFlightStream<PutResult>;
+ type HandshakeStream = BoxedFlightStream<HandshakeResponse>;
+ type ListActionsStream = BoxedFlightStream<ActionType>;
+ type ListFlightsStream = BoxedFlightStream<FlightInfo>;
+ async fn handshake(
+ &self,
+ _request: Request<Streaming<HandshakeRequest>>,
+ ) -> Result<Response<Self::HandshakeStream>, Status> {
+ Err(Status::unimplemented("handshake"))
+ }
+
+ async fn list_flights(
+ &self,
+ _request: Request<Criteria>,
+ ) -> Result<Response<Self::ListFlightsStream>, Status> {
+ Err(Status::unimplemented("list_flights"))
+ }
+
+ async fn get_flight_info(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<FlightInfo>, Status> {
+ Err(Status::unimplemented("get_flight_info"))
+ }
+
+ async fn poll_flight_info(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<PollInfo>, Status> {
+ Err(Status::unimplemented("poll_flight_info"))
+ }
+
+ async fn get_schema(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<SchemaResult>, Status> {
+ Err(Status::unimplemented("get_schema"))
+ }
+
+ async fn do_get(
+ &self,
+ request: Request<Ticket>,
+ ) -> Result<Response<Self::DoGetStream>, Status> {
+ let ticket = request.into_inner();
+
+ let action =
+ decode_protobuf(&ticket.ticket).map_err(|e|
from_ballista_err(&e))?;
+
+ match &action {
+ BallistaAction::FetchPartition {
+ host, port, job_id, ..
+ } => {
+ debug!("Fetching results for job id: {job_id} from
{host}:{port}");
+ let mut client = get_flight_client(host, port)
+ .map_err(|e| from_ballista_err(&e))
+ .await?;
+ client
+ .do_get(Request::new(ticket))
+ .await
+ .map(|r| Response::new(Box::pin(r.into_inner()) as
Self::DoGetStream))
+ }
+ }
+ }
+
+ async fn do_put(
+ &self,
+ _request: Request<Streaming<FlightData>>,
+ ) -> Result<Response<Self::DoPutStream>, Status> {
+ Err(Status::unimplemented("do_put"))
+ }
+
+ async fn do_exchange(
+ &self,
+ _request: Request<Streaming<FlightData>>,
+ ) -> Result<Response<Self::DoExchangeStream>, Status> {
+ Err(Status::unimplemented("do_exchange"))
+ }
+
+ async fn do_action(
+ &self,
+ _request: Request<Action>,
+ ) -> Result<Response<Self::DoActionStream>, Status> {
+ Err(Status::unimplemented("do_action"))
+ }
+
+ async fn list_actions(
+ &self,
+ _request: Request<Empty>,
+ ) -> Result<Response<Self::ListActionsStream>, Status> {
+ Err(Status::unimplemented("list_actions"))
+ }
+}
+
+fn from_ballista_err(e: &ballista_core::error::BallistaError) -> Status {
+ Status::internal(format!("Ballista Error: {e:?}"))
+}
+
+async fn get_flight_client(
+ host: &String,
+ port: &u16,
Review Comment:
```suggestion
port: u16,
```
##########
ballista/scheduler/src/scheduler_process.rs:
##########
@@ -121,6 +123,49 @@ pub async fn start_grpc_service<
.map_err(BallistaError::from)
}
+fn start_flight_proxy_server(
+ config: Arc<SchedulerConfig>,
+) -> JoinHandle<Result<(), BallistaError>> {
+ tokio::spawn(async move {
+ let address = match config.advertise_flight_sql_endpoint.clone() {
+ Some(flight_sql_endpoint) => flight_sql_endpoint
+ .parse::<SocketAddr>()
+ .map_err(|e: std::net::AddrParseError| {
+ error!(
+ "Error parsing advertise_flight_sql_endpoint: {}",
+ e.to_string()
+ );
+ BallistaError::Configuration(e.to_string())
+ })?,
+ _ => {
+ return Err(BallistaError::Configuration(
+ "Expected advertise flight sql endpoint".into(),
+ ));
+ }
+ };
+
+ let max_encoding_message_size =
+ config.grpc_server_max_encoding_message_size as usize;
+ let max_decoding_message_size =
+ config.grpc_server_max_decoding_message_size as usize;
+ info!("Built-in arrow flight server proxy listening on: {address:?}
max_encoding_size: {max_encoding_message_size} max_decoding_size:
{max_decoding_message_size}");
+
+ let grpc_server_config = GrpcServerConfig::default();
+ let server_future = create_grpc_server(&grpc_server_config)
Review Comment:
There is no authentication layer.
However, the main service has no authentication either, so this is not
required at the moment.
##########
ballista/scheduler/src/flight_proxy_service.rs:
##########
@@ -0,0 +1,157 @@
+use arrow_flight::flight_service_client::FlightServiceClient;
+use arrow_flight::flight_service_server::FlightService;
+use arrow_flight::{
+ Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor,
FlightInfo,
+ HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult,
Ticket,
+};
+use ballista_core::error::BallistaError;
+use ballista_core::serde::decode_protobuf;
+use ballista_core::serde::scheduler::Action as BallistaAction;
+use ballista_core::utils::{create_grpc_client_connection, GrpcClientConfig};
+use futures::{Stream, TryFutureExt};
+use log::debug;
+use std::pin::Pin;
+use tonic::{Request, Response, Status, Streaming};
+
+/// Service implementing a proxy from scheduler to executor Apache Arrow
Flight Protocol
+#[derive(Clone)]
+pub struct BallistaFlightProxyService {}
+
+impl BallistaFlightProxyService {
+ pub fn new() -> Self {
+ Self {}
+ }
+}
+
+impl Default for BallistaFlightProxyService {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+type BoxedFlightStream<T> =
+ Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + 'static>>;
+
+#[tonic::async_trait]
+impl FlightService for BallistaFlightProxyService {
+ type DoActionStream = BoxedFlightStream<arrow_flight::Result>;
+ type DoExchangeStream = BoxedFlightStream<FlightData>;
+ type DoGetStream = BoxedFlightStream<FlightData>;
+ type DoPutStream = BoxedFlightStream<PutResult>;
+ type HandshakeStream = BoxedFlightStream<HandshakeResponse>;
+ type ListActionsStream = BoxedFlightStream<ActionType>;
+ type ListFlightsStream = BoxedFlightStream<FlightInfo>;
+ async fn handshake(
+ &self,
+ _request: Request<Streaming<HandshakeRequest>>,
+ ) -> Result<Response<Self::HandshakeStream>, Status> {
+ Err(Status::unimplemented("handshake"))
+ }
+
+ async fn list_flights(
+ &self,
+ _request: Request<Criteria>,
+ ) -> Result<Response<Self::ListFlightsStream>, Status> {
+ Err(Status::unimplemented("list_flights"))
+ }
+
+ async fn get_flight_info(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<FlightInfo>, Status> {
+ Err(Status::unimplemented("get_flight_info"))
+ }
+
+ async fn poll_flight_info(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<PollInfo>, Status> {
+ Err(Status::unimplemented("poll_flight_info"))
+ }
+
+ async fn get_schema(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<SchemaResult>, Status> {
+ Err(Status::unimplemented("get_schema"))
+ }
+
+ async fn do_get(
+ &self,
+ request: Request<Ticket>,
+ ) -> Result<Response<Self::DoGetStream>, Status> {
+ let ticket = request.into_inner();
+
+ let action =
+ decode_protobuf(&ticket.ticket).map_err(|e|
from_ballista_err(&e))?;
+
+ match &action {
+ BallistaAction::FetchPartition {
+ host, port, job_id, ..
+ } => {
+ debug!("Fetching results for job id: {job_id} from
{host}:{port}");
+ let mut client = get_flight_client(host, port)
+ .map_err(|e| from_ballista_err(&e))
+ .await?;
+ client
+ .do_get(Request::new(ticket))
+ .await
+ .map(|r| Response::new(Box::pin(r.into_inner()) as
Self::DoGetStream))
+ }
+ }
+ }
+
+ async fn do_put(
+ &self,
+ _request: Request<Streaming<FlightData>>,
+ ) -> Result<Response<Self::DoPutStream>, Status> {
+ Err(Status::unimplemented("do_put"))
+ }
+
+ async fn do_exchange(
+ &self,
+ _request: Request<Streaming<FlightData>>,
+ ) -> Result<Response<Self::DoExchangeStream>, Status> {
+ Err(Status::unimplemented("do_exchange"))
+ }
+
+ async fn do_action(
+ &self,
+ _request: Request<Action>,
+ ) -> Result<Response<Self::DoActionStream>, Status> {
+ Err(Status::unimplemented("do_action"))
+ }
+
+ async fn list_actions(
+ &self,
+ _request: Request<Empty>,
+ ) -> Result<Response<Self::ListActionsStream>, Status> {
+ Err(Status::unimplemented("list_actions"))
+ }
+}
+
+fn from_ballista_err(e: &ballista_core::error::BallistaError) -> Status {
+ Status::internal(format!("Ballista Error: {e:?}"))
+}
+
+async fn get_flight_client(
+ host: &String,
+ port: &u16,
+) -> Result<FlightServiceClient<tonic::transport::channel::Channel>,
BallistaError> {
+ let addr = format!("http://{host}:{port}");
+ let grpc_config = GrpcClientConfig::default();
+ debug!("FlightProxyService connecting to {addr}");
+ let connection = create_grpc_client_connection(addr.clone(), &grpc_config)
Review Comment:
Since the host and port come from a client-provided ticket, I think it would be
good to check that they point to a registered executor. Otherwise the proxy
could be abused for a Server-Side Request Forgery (SSRF) attack.
##########
ballista/scheduler/src/flight_proxy_service.rs:
##########
@@ -0,0 +1,157 @@
+use arrow_flight::flight_service_client::FlightServiceClient;
+use arrow_flight::flight_service_server::FlightService;
+use arrow_flight::{
+ Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor,
FlightInfo,
+ HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult,
Ticket,
+};
+use ballista_core::error::BallistaError;
+use ballista_core::serde::decode_protobuf;
+use ballista_core::serde::scheduler::Action as BallistaAction;
+use ballista_core::utils::{create_grpc_client_connection, GrpcClientConfig};
+use futures::{Stream, TryFutureExt};
+use log::debug;
+use std::pin::Pin;
+use tonic::{Request, Response, Status, Streaming};
+
+/// Service implementing a proxy from scheduler to executor Apache Arrow
Flight Protocol
+#[derive(Clone)]
+pub struct BallistaFlightProxyService {}
+
+impl BallistaFlightProxyService {
+ pub fn new() -> Self {
+ Self {}
+ }
+}
+
+impl Default for BallistaFlightProxyService {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+type BoxedFlightStream<T> =
+ Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + 'static>>;
+
+#[tonic::async_trait]
+impl FlightService for BallistaFlightProxyService {
+ type DoActionStream = BoxedFlightStream<arrow_flight::Result>;
+ type DoExchangeStream = BoxedFlightStream<FlightData>;
+ type DoGetStream = BoxedFlightStream<FlightData>;
+ type DoPutStream = BoxedFlightStream<PutResult>;
+ type HandshakeStream = BoxedFlightStream<HandshakeResponse>;
+ type ListActionsStream = BoxedFlightStream<ActionType>;
+ type ListFlightsStream = BoxedFlightStream<FlightInfo>;
+ async fn handshake(
+ &self,
+ _request: Request<Streaming<HandshakeRequest>>,
+ ) -> Result<Response<Self::HandshakeStream>, Status> {
+ Err(Status::unimplemented("handshake"))
+ }
+
+ async fn list_flights(
+ &self,
+ _request: Request<Criteria>,
+ ) -> Result<Response<Self::ListFlightsStream>, Status> {
+ Err(Status::unimplemented("list_flights"))
+ }
+
+ async fn get_flight_info(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<FlightInfo>, Status> {
+ Err(Status::unimplemented("get_flight_info"))
+ }
+
+ async fn poll_flight_info(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<PollInfo>, Status> {
+ Err(Status::unimplemented("poll_flight_info"))
+ }
+
+ async fn get_schema(
+ &self,
+ _request: Request<FlightDescriptor>,
+ ) -> Result<Response<SchemaResult>, Status> {
+ Err(Status::unimplemented("get_schema"))
+ }
+
+ async fn do_get(
+ &self,
+ request: Request<Ticket>,
+ ) -> Result<Response<Self::DoGetStream>, Status> {
+ let ticket = request.into_inner();
+
+ let action =
+ decode_protobuf(&ticket.ticket).map_err(|e|
from_ballista_err(&e))?;
+
+ match &action {
+ BallistaAction::FetchPartition {
+ host, port, job_id, ..
+ } => {
+ debug!("Fetching results for job id: {job_id} from
{host}:{port}");
+ let mut client = get_flight_client(host, port)
+ .map_err(|e| from_ballista_err(&e))
+ .await?;
+ client
+ .do_get(Request::new(ticket))
+ .await
+ .map(|r| Response::new(Box::pin(r.into_inner()) as
Self::DoGetStream))
+ }
+ }
+ }
+
+ async fn do_put(
+ &self,
+ _request: Request<Streaming<FlightData>>,
+ ) -> Result<Response<Self::DoPutStream>, Status> {
+ Err(Status::unimplemented("do_put"))
+ }
+
+ async fn do_exchange(
+ &self,
+ _request: Request<Streaming<FlightData>>,
+ ) -> Result<Response<Self::DoExchangeStream>, Status> {
+ Err(Status::unimplemented("do_exchange"))
+ }
+
+ async fn do_action(
+ &self,
+ _request: Request<Action>,
+ ) -> Result<Response<Self::DoActionStream>, Status> {
+ Err(Status::unimplemented("do_action"))
+ }
+
+ async fn list_actions(
+ &self,
+ _request: Request<Empty>,
+ ) -> Result<Response<Self::ListActionsStream>, Status> {
+ Err(Status::unimplemented("list_actions"))
+ }
+}
+
+fn from_ballista_err(e: &ballista_core::error::BallistaError) -> Status {
+ Status::internal(format!("Ballista Error: {e:?}"))
+}
+
+async fn get_flight_client(
+ host: &String,
+ port: &u16,
+) -> Result<FlightServiceClient<tonic::transport::channel::Channel>,
BallistaError> {
Review Comment:
Consider connection pooling. Currently a new client is created for each
request.
##########
ballista/core/proto/ballista.proto:
##########
@@ -707,6 +707,12 @@ message RunningTaskInfo {
uint32 partition_id = 4;;
Review Comment:
```suggestion
uint32 partition_id = 4;
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]