This is an automated email from the ASF dual-hosted git repository.
milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git
The following commit(s) were added to refs/heads/main by this push:
new 5e98229e3 feat: job scheduling with push based job status updates
(#1478)
5e98229e3 is described below
commit 5e98229e3ad1eb58f395c79e0a0dc9e01df43df0
Author: Marko Milenković <[email protected]>
AuthorDate: Mon Mar 2 17:05:16 2026 +0000
feat: job scheduling with push based job status updates (#1478)
* implement push based job execution
* minor cleanup
* add additional test
* refactor, extract common code to methods
* fix job name issue
* clone subscriber, not to keep awaiting in a lock
* addressing few comments
* remove print
* update sender to use try send
* fix clippy
---
ballista/client/tests/context_checks.rs | 58 +++++
ballista/core/proto/ballista.proto | 2 +
ballista/core/src/config.rs | 13 +-
.../core/src/execution_plans/distributed_query.rs | 240 ++++++++++++++++---
ballista/core/src/lib.rs | 5 +
ballista/core/src/serde/generated/ballista.rs | 89 +++++++
ballista/scheduler/src/cluster/memory.rs | 64 +++++-
ballista/scheduler/src/cluster/mod.rs | 9 +-
ballista/scheduler/src/cluster/test_util/mod.rs | 2 +-
ballista/scheduler/src/scheduler_server/event.rs | 4 +-
ballista/scheduler/src/scheduler_server/grpc.rs | 255 ++++++++++++++-------
ballista/scheduler/src/scheduler_server/mod.rs | 223 +++++++++++++++++-
.../src/scheduler_server/query_stage_scheduler.rs | 37 ++-
ballista/scheduler/src/state/mod.rs | 3 +
ballista/scheduler/src/state/task_manager.rs | 8 +-
ballista/scheduler/src/test_utils.rs | 17 +-
16 files changed, 897 insertions(+), 132 deletions(-)
diff --git a/ballista/client/tests/context_checks.rs
b/ballista/client/tests/context_checks.rs
index 908e9c23d..0faf3b9b7 100644
--- a/ballista/client/tests/context_checks.rs
+++ b/ballista/client/tests/context_checks.rs
@@ -1047,4 +1047,62 @@ mod supported {
];
assert_batches_eq!(expected, &result);
}
+
+ #[rstest]
+ #[case::standalone(standalone_context())]
+ #[case::remote(remote_context())]
+ #[tokio::test]
+ async fn should_force_client_pull(
+ #[future(awt)]
+ #[case]
+ ctx: SessionContext,
+ test_data: String,
+ ) -> datafusion::error::Result<()> {
+ ctx.register_parquet(
+ "test",
+ &format!("{test_data}/alltypes_plain.parquet"),
+ Default::default(),
+ )
+ .await?;
+
+ ctx.sql("SET ballista.client.pull = true")
+ .await?
+ .show()
+ .await?;
+
+ let result = ctx
+ .sql("select name, value from information_schema.df_settings where
name like 'ballista.client.pull' order by name limit 1")
+ .await?
+ .collect()
+ .await?;
+
+ let expected = [
+ "+----------------------+-------+",
+ "| name | value |",
+ "+----------------------+-------+",
+ "| ballista.client.pull | true |",
+ "+----------------------+-------+",
+ ];
+
+ assert_batches_eq!(expected, &result);
+
+ let expected = [
+ "+------------+----------+",
+ "| string_col | count(*) |",
+ "+------------+----------+",
+ "| 30 | 1 |",
+ "| 31 | 2 |",
+ "+------------+----------+",
+ ];
+
+ let result = ctx
+ .sql("select string_col, count(*) from test where id > 4 group by
string_col order by string_col")
+ .await?
+ .collect()
+ .await?;
+
+ assert_batches_eq!(expected, &result);
+
+ Ok(())
+ }
}
diff --git a/ballista/core/proto/ballista.proto
b/ballista/core/proto/ballista.proto
index 97812bb0d..641aa9aac 100644
--- a/ballista/core/proto/ballista.proto
+++ b/ballista/core/proto/ballista.proto
@@ -759,6 +759,8 @@ service SchedulerGrpc {
rpc RemoveSession (RemoveSessionParams) returns (RemoveSessionResult) {}
+ rpc ExecuteQueryPush (ExecuteQueryParams) returns (stream
GetJobStatusResult) {}
+
rpc ExecuteQuery (ExecuteQueryParams) returns (ExecuteQueryResult) {}
rpc GetJobStatus (GetJobStatusParams) returns (GetJobStatusResult) {}
diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs
index 15e031a16..ca151558c 100644
--- a/ballista/core/src/config.rs
+++ b/ballista/core/src/config.rs
@@ -80,6 +80,8 @@ pub const BALLISTA_SHUFFLE_SORT_BASED_SPILL_THRESHOLD: &str =
/// Configuration key for sort shuffle target batch size in rows.
pub const BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE: &str =
"ballista.shuffle.sort_based.batch_size";
+/// Should client employ pull or push job tracking strategy
+pub const BALLISTA_CLIENT_PULL: &str = "ballista.client.pull";
/// Result type for configuration parsing operations.
pub type ParseResult<T> = result::Result<T, String>;
@@ -156,7 +158,11 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String,
ConfigEntry>> = LazyLock::new(||
ConfigEntry::new(BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE.to_string(),
"Target batch size in rows for coalescing small
batches in sort shuffle".to_string(),
DataType::UInt64,
- Some((8192).to_string()))
+ Some((8192).to_string())),
+ ConfigEntry::new(BALLISTA_CLIENT_PULL.to_string(),
+ "Should client employ pull or push job tracking. In
pull mode client will make a request to server in the loop, until job finishes.
Pull mode is kept for legacy clients.".to_string(),
+ DataType::Boolean,
+ Some(false.to_string()))
];
entries
.into_iter()
@@ -362,6 +368,11 @@ impl BallistaConfig {
self.get_usize_setting(BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE)
}
+ /// Should client employ pull or push job tracking strategy
+ pub fn client_pull(&self) -> bool {
+ self.get_bool_setting(BALLISTA_CLIENT_PULL)
+ }
+
fn get_usize_setting(&self, key: &str) -> usize {
if let Some(v) = self.settings.get(key) {
// infallible because we validate all configs in the constructor
diff --git a/ballista/core/src/execution_plans/distributed_query.rs
b/ballista/core/src/execution_plans/distributed_query.rs
index 8fb6935c4..5f1ad258d 100644
--- a/ballista/core/src/execution_plans/distributed_query.rs
+++ b/ballista/core/src/execution_plans/distributed_query.rs
@@ -247,33 +247,62 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for
DistributedQueryExec<T> {
let session_config = context.session_config().clone();
- let stream = futures::stream::once(
- execute_query(
- self.scheduler_url.clone(),
- self.session_id.clone(),
- query,
- self.config.default_grpc_client_max_message_size(),
- GrpcClientConfig::from(&self.config),
- Arc::new(self.metrics.clone()),
- partition,
- session_config,
+ if session_config.ballista_config().client_pull() {
+ let stream = futures::stream::once(
+ execute_query_pull(
+ self.scheduler_url.clone(),
+ self.session_id.clone(),
+ query,
+ self.config.default_grpc_client_max_message_size(),
+ GrpcClientConfig::from(&self.config),
+ Arc::new(self.metrics.clone()),
+ partition,
+ session_config,
+ )
+ .map_err(|e| ArrowError::ExternalError(Box::new(e))),
)
- .map_err(|e| ArrowError::ExternalError(Box::new(e))),
- )
- .try_flatten()
- .inspect(move |batch| {
- metric_total_bytes.add(
- batch
- .as_ref()
- .map(|b| b.get_array_memory_size())
- .unwrap_or(0),
- );
-
- metric_row_count.add(batch.as_ref().map(|b|
b.num_rows()).unwrap_or(0));
- });
-
- let schema = self.schema();
- Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
+ .try_flatten()
+ .inspect(move |batch| {
+ metric_total_bytes.add(
+ batch
+ .as_ref()
+ .map(|b| b.get_array_memory_size())
+ .unwrap_or(0),
+ );
+
+ metric_row_count.add(batch.as_ref().map(|b|
b.num_rows()).unwrap_or(0));
+ });
+
+ let schema = self.schema();
+ Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
+ } else {
+ let stream = futures::stream::once(
+ execute_query_push(
+ self.scheduler_url.clone(),
+ query,
+ self.config.default_grpc_client_max_message_size(),
+ GrpcClientConfig::from(&self.config),
+ Arc::new(self.metrics.clone()),
+ partition,
+ session_config,
+ )
+ .map_err(|e| ArrowError::ExternalError(Box::new(e))),
+ )
+ .try_flatten()
+ .inspect(move |batch| {
+ metric_total_bytes.add(
+ batch
+ .as_ref()
+ .map(|b| b.get_array_memory_size())
+ .unwrap_or(0),
+ );
+
+ metric_row_count.add(batch.as_ref().map(|b|
b.num_rows()).unwrap_or(0));
+ });
+
+ let schema = self.schema();
+ Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
+ }
}
fn statistics(&self) -> Result<Statistics> {
@@ -288,8 +317,11 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for
DistributedQueryExec<T> {
}
}
+/// Client will periodically invoke scheduler to check
+/// job status. There is preconfigured wait period between
+/// pulls, which increases query latency.
#[allow(clippy::too_many_arguments)]
-async fn execute_query(
+async fn execute_query_pull(
scheduler_url: String,
session_id: String,
query: ExecuteQueryParams,
@@ -453,6 +485,160 @@ async fn execute_query(
};
}
}
+/// After a job is scheduled the client waits
+/// for job updates, which are streamed back
+/// from the server to the client
+#[allow(clippy::too_many_arguments)]
+async fn execute_query_push(
+ scheduler_url: String,
+ query: ExecuteQueryParams,
+ max_message_size: usize,
+ grpc_config: GrpcClientConfig,
+ metrics: Arc<ExecutionPlanMetricsSet>,
+ partition: usize,
+ session_config: SessionConfig,
+) -> Result<impl Stream<Item = Result<RecordBatch>> + Send> {
+ let grpc_interceptor = session_config.ballista_grpc_interceptor();
+ let customize_endpoint =
+ session_config.ballista_override_create_grpc_client_endpoint();
+ let use_tls = session_config.ballista_use_tls();
+
+ // Capture query submission time for total_query_time_ms
+ let query_start_time = std::time::Instant::now();
+
+ info!("Connecting to Ballista scheduler at {scheduler_url}");
+ // TODO reuse the scheduler to avoid connecting to the Ballista scheduler
again and again
+ let mut endpoint =
+ create_grpc_client_endpoint(scheduler_url.clone(), Some(&grpc_config))
+ .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
+
+ if let Some(ref customize) = customize_endpoint {
+ endpoint = customize
+ .configure_endpoint(endpoint)
+ .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
+ }
+
+ let connection = endpoint
+ .connect()
+ .await
+ .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
+
+ let mut scheduler = SchedulerGrpcClient::with_interceptor(
+ connection,
+ grpc_interceptor.as_ref().clone(),
+ )
+ .max_encoding_message_size(max_message_size)
+ .max_decoding_message_size(max_message_size);
+
+ let mut query_status_stream = scheduler
+ .execute_query_push(query)
+ .await
+ .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
+ .into_inner();
+
+ let mut prev_status: Option<job_status::Status> = None;
+
+ loop {
+ let item = query_status_stream
+ .next()
+ .await
+ .ok_or(DataFusionError::Execution(
+ "Stream closed without job completing".to_string(),
+ ))?
+ .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+
+ let GetJobStatusResult {
+ status,
+ flight_proxy,
+ } = item;
+ let job_id = status
+ .as_ref()
+ .map(|s| s.job_id.to_owned())
+ .unwrap_or("unknown_job_id".to_string()); // should not happen
+ let status = status.and_then(|s| s.status);
+ let has_status_change = prev_status != status;
+ match status {
+ None => {
+ if has_status_change {
+ info!("Job {job_id} is in initialization ...");
+ }
+ prev_status = status;
+ }
+ Some(job_status::Status::Queued(_)) => {
+ if has_status_change {
+ info!("Job {job_id} is queued...");
+ }
+ prev_status = status;
+ }
+ Some(job_status::Status::Running(_)) => {
+ if has_status_change {
+ info!("Job {job_id} is running...");
+ }
+ prev_status = status;
+ }
+ Some(job_status::Status::Failed(err)) => {
+ let msg = format!("Job {} failed: {}", job_id, err.error);
+ error!("{msg}");
+ break Err(DataFusionError::Execution(msg));
+ }
+ Some(job_status::Status::Successful(SuccessfulJob {
+ queued_at,
+ started_at,
+ ended_at,
+ partition_location,
+ ..
+ })) => {
+ // Calculate job execution time (server-side execution)
+ let job_execution_ms = ended_at.saturating_sub(started_at);
+ let duration = Duration::from_millis(job_execution_ms);
+
+ info!("Job {job_id} finished executing in {duration:?} ");
+
+ // Calculate scheduling time (server-side queue time)
+ // This includes network latency and actual queue time
+ let scheduling_ms = started_at.saturating_sub(queued_at);
+
+ // Calculate total query time (end-to-end from client
perspective)
+ let total_elapsed = query_start_time.elapsed();
+ let total_ms = total_elapsed.as_millis();
+
+ // Set timing metrics
+ let metric_job_execution = MetricBuilder::new(&metrics)
+ .gauge("job_execution_time_ms", partition);
+ metric_job_execution.set(job_execution_ms as usize);
+
+ let metric_scheduling =
+ MetricBuilder::new(&metrics).gauge("job_scheduling_in_ms",
partition);
+ metric_scheduling.set(scheduling_ms as usize);
+
+ let metric_total_time =
+ MetricBuilder::new(&metrics).gauge("total_query_time_ms",
partition);
+ metric_total_time.set(total_ms as usize);
+
+ // Note: data_transfer_time_ms is not set here because
partition fetching
+ // happens lazily when the stream is consumed, not during
execute_query.
+ // This could be added in a future enhancement by wrapping the
stream.
+
+ let streams = partition_location.into_iter().map(move
|partition| {
+ let f = fetch_partition(
+ partition,
+ max_message_size,
+ true,
+ scheduler_url.clone(),
+ flight_proxy.clone(),
+ customize_endpoint.clone(),
+ use_tls,
+ )
+ .map_err(|e| ArrowError::ExternalError(Box::new(e)));
+
+ futures::stream::once(f).try_flatten()
+ });
+
+ break Ok(futures::stream::iter(streams).flatten());
+ }
+ };
+ }
+}
fn get_client_host_port(
executor_metadata: &ExecutorMetadata,
diff --git a/ballista/core/src/lib.rs b/ballista/core/src/lib.rs
index 2055e723f..c9c4ef1d5 100644
--- a/ballista/core/src/lib.rs
+++ b/ballista/core/src/lib.rs
@@ -21,6 +21,8 @@
use std::sync::Arc;
use datafusion::{execution::runtime_env::RuntimeEnv, prelude::SessionConfig};
+
+use crate::serde::protobuf::JobStatus;
/// The current version of Ballista, derived from the Cargo package version.
pub const BALLISTA_VERSION: &str = env!("CARGO_PKG_VERSION");
@@ -76,3 +78,6 @@ pub type RuntimeProducer = Arc<
/// It is intended to be used with executor configuration
///
pub type ConfigProducer = Arc<dyn Fn() -> SessionConfig + Send + Sync>;
+
+/// Job Notification Subscriber
+pub type JobStatusSubscriber = tokio::sync::mpsc::Sender<JobStatus>;
diff --git a/ballista/core/src/serde/generated/ballista.rs
b/ballista/core/src/serde/generated/ballista.rs
index adeeaa8a5..1cfe102e5 100644
--- a/ballista/core/src/serde/generated/ballista.rs
+++ b/ballista/core/src/serde/generated/ballista.rs
@@ -1386,6 +1386,35 @@ pub mod scheduler_grpc_client {
);
self.inner.unary(req, path, codec).await
}
+ pub async fn execute_query_push(
+ &mut self,
+ request: impl tonic::IntoRequest<super::ExecuteQueryParams>,
+ ) -> std::result::Result<
+
tonic::Response<tonic::codec::Streaming<super::GetJobStatusResult>>,
+ tonic::Status,
+ > {
+ self.inner
+ .ready()
+ .await
+ .map_err(|e| {
+ tonic::Status::unknown(
+ format!("Service was not ready: {}", e.into()),
+ )
+ })?;
+ let codec = tonic_prost::ProstCodec::default();
+ let path = http::uri::PathAndQuery::from_static(
+ "/ballista.protobuf.SchedulerGrpc/ExecuteQueryPush",
+ );
+ let mut req = request.into_request();
+ req.extensions_mut()
+ .insert(
+ GrpcMethod::new(
+ "ballista.protobuf.SchedulerGrpc",
+ "ExecuteQueryPush",
+ ),
+ );
+ self.inner.server_streaming(req, path, codec).await
+ }
pub async fn execute_query(
&mut self,
request: impl tonic::IntoRequest<super::ExecuteQueryParams>,
@@ -1569,6 +1598,19 @@ pub mod scheduler_grpc_server {
tonic::Response<super::RemoveSessionResult>,
tonic::Status,
>;
+ /// Server streaming response type for the ExecuteQueryPush method.
+ type ExecuteQueryPushStream: tonic::codegen::tokio_stream::Stream<
+ Item = std::result::Result<super::GetJobStatusResult,
tonic::Status>,
+ >
+ + std::marker::Send
+ + 'static;
+ async fn execute_query_push(
+ &self,
+ request: tonic::Request<super::ExecuteQueryParams>,
+ ) -> std::result::Result<
+ tonic::Response<Self::ExecuteQueryPushStream>,
+ tonic::Status,
+ >;
async fn execute_query(
&self,
request: tonic::Request<super::ExecuteQueryParams>,
@@ -1956,6 +1998,53 @@ pub mod scheduler_grpc_server {
};
Box::pin(fut)
}
+ "/ballista.protobuf.SchedulerGrpc/ExecuteQueryPush" => {
+ #[allow(non_camel_case_types)]
+ struct ExecuteQueryPushSvc<T: SchedulerGrpc>(pub Arc<T>);
+ impl<
+ T: SchedulerGrpc,
+ >
tonic::server::ServerStreamingService<super::ExecuteQueryParams>
+ for ExecuteQueryPushSvc<T> {
+ type Response = super::GetJobStatusResult;
+ type ResponseStream = T::ExecuteQueryPushStream;
+ type Future = BoxFuture<
+ tonic::Response<Self::ResponseStream>,
+ tonic::Status,
+ >;
+ fn call(
+ &mut self,
+ request: tonic::Request<super::ExecuteQueryParams>,
+ ) -> Self::Future {
+ let inner = Arc::clone(&self.0);
+ let fut = async move {
+ <T as
SchedulerGrpc>::execute_query_push(&inner, request)
+ .await
+ };
+ Box::pin(fut)
+ }
+ }
+ let accept_compression_encodings =
self.accept_compression_encodings;
+ let send_compression_encodings =
self.send_compression_encodings;
+ let max_decoding_message_size =
self.max_decoding_message_size;
+ let max_encoding_message_size =
self.max_encoding_message_size;
+ let inner = self.inner.clone();
+ let fut = async move {
+ let method = ExecuteQueryPushSvc(inner);
+ let codec = tonic_prost::ProstCodec::default();
+ let mut grpc = tonic::server::Grpc::new(codec)
+ .apply_compression_config(
+ accept_compression_encodings,
+ send_compression_encodings,
+ )
+ .apply_max_message_size_config(
+ max_decoding_message_size,
+ max_encoding_message_size,
+ );
+ let res = grpc.server_streaming(method, req).await;
+ Ok(res)
+ };
+ Box::pin(fut)
+ }
"/ballista.protobuf.SchedulerGrpc/ExecuteQuery" => {
#[allow(non_camel_case_types)]
struct ExecuteQuerySvc<T: SchedulerGrpc>(pub Arc<T>);
diff --git a/ballista/scheduler/src/cluster/memory.rs
b/ballista/scheduler/src/cluster/memory.rs
index 87ef6709f..f3f70e99c 100644
--- a/ballista/scheduler/src/cluster/memory.rs
+++ b/ballista/scheduler/src/cluster/memory.rs
@@ -23,22 +23,23 @@ use crate::cluster::{
};
use crate::state::execution_graph::ExecutionGraphBox;
use async_trait::async_trait;
-use ballista_core::ConfigProducer;
use ballista_core::error::{BallistaError, Result};
use ballista_core::serde::protobuf::{
AvailableTaskSlots, ExecutorHeartbeat, ExecutorStatus, FailedJob,
QueuedJob,
executor_status,
};
use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
+use ballista_core::{ConfigProducer, JobStatusSubscriber};
use dashmap::DashMap;
use datafusion::prelude::{SessionConfig, SessionContext};
+use tokio::sync::mpsc::error::TrySendError;
use crate::cluster::event::ClusterEventSender;
use crate::scheduler_server::{SessionBuilder, timestamp_millis,
timestamp_secs};
use crate::state::session_manager::create_datafusion_context;
use crate::state::task_manager::JobInfoCache;
use ballista_core::serde::protobuf::job_status::Status;
-use log::{debug, error, info, warn};
+use log::{error, info, warn};
use std::collections::{HashMap, HashSet};
use std::ops::DerefMut;
@@ -351,7 +352,7 @@ pub struct InMemoryJobState {
/// In-memory store of queued jobs. Map from Job ID -> (Job Name,
queued_at timestamp)
queued_jobs: DashMap<String, (String, u64)>,
/// In-memory store of running job statuses. Map from Job ID -> JobStatus
- running_jobs: DashMap<String, JobStatus>,
+ running_jobs: DashMap<String, ExtendedJobStatus>,
/// `SessionBuilder` for building DataFusion `SessionContext` from
`BallistaConfig`
session_builder: SessionBuilder,
/// Sender of job events
@@ -380,12 +381,42 @@ impl InMemoryJobState {
}
}
+#[derive(Clone)]
+struct ExtendedJobStatus {
+ status: JobStatus,
+ subscriber: Option<JobStatusSubscriber>,
+}
+
+impl ExtendedJobStatus {
+ fn update_subscribers(&self, status: JobStatus) {
+ let job_id = status.job_id.clone();
+ if let Some(subscriber) = &self.subscriber
+ && matches!(subscriber.try_send(status),
Err(TrySendError::Full(_)))
+ {
+ // to be considered if we need another task to try to push this
notification
+            // at the moment, it does not look necessary, as the buffer should
be big enough for all cases
+ error!(
+ "jobs notification subscriber for job {} is blocked, can't
deliver status update, job notification will be missed",
+ job_id
+ )
+ }
+ }
+}
+
#[async_trait]
impl JobState for InMemoryJobState {
- async fn submit_job(&self, job_id: String, graph: &ExecutionGraphBox) ->
Result<()> {
+ async fn submit_job(
+ &self,
+ job_id: String,
+ graph: &ExecutionGraphBox,
+ subscriber: Option<JobStatusSubscriber>,
+ ) -> Result<()> {
if self.queued_jobs.get(&job_id).is_some() {
- self.running_jobs
- .insert(job_id.clone(), graph.status().clone());
+ let status = ExtendedJobStatus {
+ status: graph.status().clone(),
+ subscriber,
+ };
+ self.running_jobs.insert(job_id.clone(), status);
self.queued_jobs.remove(&job_id);
self.job_event_sender.send(&JobStateEvent::JobAcquired {
@@ -413,7 +444,7 @@ impl JobState for InMemoryJobState {
}
if let Some(status) =
self.running_jobs.get(job_id).as_deref().cloned() {
- return Ok(Some(status));
+ return Ok(Some(status.status));
}
if let Some((status, _)) = self.completed_jobs.get(job_id).as_deref() {
@@ -442,20 +473,29 @@ impl JobState for InMemoryJobState {
async fn save_job(&self, job_id: &str, graph: &ExecutionGraphBox) ->
Result<()> {
let status = graph.status().clone();
-
- debug!("saving state for job {job_id} with status {:?}", status);
-
// If job is either successful or failed, save to completed jobs
if matches!(
status.status,
Some(Status::Successful(_)) | Some(Status::Failed(_))
) {
+ if let Some((_, job_info)) = self.running_jobs.remove(job_id) {
+ job_info.update_subscribers(status.clone());
+ }
+
self.completed_jobs
.insert(job_id.to_string(), (status.clone(),
Some(graph.cloned())));
- self.running_jobs.remove(job_id);
} else {
// otherwise update running job
- self.running_jobs.insert(job_id.to_string(), status.clone());
+ if let Some(mut job_info) = self.running_jobs.get_mut(job_id) {
+ job_info.status = status.clone();
+                // we're cloning the subscriber so as not to await while holding the lock
+ job_info.update_subscribers(status.clone());
+ } else {
+ Err(BallistaError::Internal(format!(
+ "scheduler state can't find job: {}",
+ job_id
+ )))?
+ };
}
// job change event emitted
diff --git a/ballista/scheduler/src/cluster/mod.rs
b/ballista/scheduler/src/cluster/mod.rs
index 5d01091cc..5647ff89a 100644
--- a/ballista/scheduler/src/cluster/mod.rs
+++ b/ballista/scheduler/src/cluster/mod.rs
@@ -37,7 +37,7 @@ use ballista_core::serde::protobuf::{
};
use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata,
PartitionId};
use ballista_core::utils::{default_config_producer, default_session_builder};
-use ballista_core::{ConfigProducer, consistent_hash};
+use ballista_core::{ConfigProducer, JobStatusSubscriber, consistent_hash};
use crate::cluster::memory::{InMemoryClusterState, InMemoryJobState};
@@ -301,7 +301,12 @@ pub trait JobState: Send + Sync {
/// Submits a new job to the job state.
///
/// The submitter is assumed to own the job.
- async fn submit_job(&self, job_id: String, graph: &ExecutionGraphBox) ->
Result<()>;
+ async fn submit_job(
+ &self,
+ job_id: String,
+ graph: &ExecutionGraphBox,
+ subscriber: Option<JobStatusSubscriber>,
+ ) -> Result<()>;
/// Returns the set of all active job IDs.
async fn get_jobs(&self) -> Result<HashSet<String>>;
diff --git a/ballista/scheduler/src/cluster/test_util/mod.rs
b/ballista/scheduler/src/cluster/test_util/mod.rs
index d9e960004..71693930e 100644
--- a/ballista/scheduler/src/cluster/test_util/mod.rs
+++ b/ballista/scheduler/src/cluster/test_util/mod.rs
@@ -89,7 +89,7 @@ impl<S: JobState> JobStateTest<S> {
/// Submits a job with the given execution graph.
pub async fn submit_job(self, graph: &ExecutionGraphBox) -> Result<Self> {
self.state
- .submit_job(graph.job_id().to_string(), graph)
+ .submit_job(graph.job_id().to_string(), graph, None)
.await?;
Ok(self)
}
diff --git a/ballista/scheduler/src/scheduler_server/event.rs
b/ballista/scheduler/src/scheduler_server/event.rs
index 77ef7a65e..c6d11fb1b 100644
--- a/ballista/scheduler/src/scheduler_server/event.rs
+++ b/ballista/scheduler/src/scheduler_server/event.rs
@@ -20,7 +20,7 @@ use std::fmt::{Debug, Formatter};
use datafusion::logical_expr::LogicalPlan;
use crate::state::execution_graph::RunningTaskInfo;
-use ballista_core::serde::protobuf::TaskStatus;
+use ballista_core::{JobStatusSubscriber, serde::protobuf::TaskStatus};
use datafusion::prelude::SessionContext;
use std::sync::Arc;
@@ -39,6 +39,8 @@ pub enum QueryStageSchedulerEvent {
plan: Box<LogicalPlan>,
/// Timestamp when the job was queued.
queued_at: u64,
+ /// job status subscriber
+ subscriber: Option<JobStatusSubscriber>,
},
/// A job has been submitted for execution.
JobSubmitted {
diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs
b/ballista/scheduler/src/scheduler_server/grpc.rs
index 072b8a4d3..4e1e9e5b1 100644
--- a/ballista/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/scheduler/src/scheduler_server/grpc.rs
@@ -17,6 +17,7 @@
use axum::extract::ConnectInfo;
use ballista_core::config::BALLISTA_JOB_NAME;
+use ballista_core::error::{BallistaError, Result as BResult};
use ballista_core::extension::SessionConfigHelperExt;
use ballista_core::serde::protobuf::execute_query_params::Query;
use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc;
@@ -26,16 +27,20 @@ use ballista_core::serde::protobuf::{
ExecuteQueryFailureResult, ExecuteQueryParams, ExecuteQueryResult,
ExecuteQuerySuccessResult, ExecutorHeartbeat, ExecutorStoppedParams,
ExecutorStoppedResult, GetJobStatusParams, GetJobStatusResult,
HeartBeatParams,
- HeartBeatResult, PollWorkParams, PollWorkResult, RegisterExecutorParams,
- RegisterExecutorResult, RemoveSessionParams, RemoveSessionResult,
- UpdateTaskStatusParams, UpdateTaskStatusResult,
execute_query_failure_result,
- execute_query_result,
+ HeartBeatResult, JobStatus, KeyValuePair, PollWorkParams, PollWorkResult,
+ RegisterExecutorParams, RegisterExecutorResult, RemoveSessionParams,
+ RemoveSessionResult, UpdateTaskStatusParams, UpdateTaskStatusResult,
+ execute_query_failure_result, execute_query_result,
};
use ballista_core::serde::scheduler::ExecutorMetadata;
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
+use futures::{Stream, StreamExt};
use log::{debug, error, info, trace, warn};
use std::net::SocketAddr;
+use std::pin::Pin;
+
+use tokio_stream::wrappers::ReceiverStream;
#[cfg(feature = "substrait")]
use {
@@ -43,12 +48,14 @@ use {
datafusion_substrait::serializer::deserialize_bytes,
};
-use std::ops::Deref;
-
use crate::cluster::{bind_task_bias, bind_task_round_robin};
use crate::config::TaskDistributionPolicy;
use crate::scheduler_server::event::QueryStageSchedulerEvent;
use ballista_core::serde::protobuf::get_job_status_result::FlightProxy;
+use datafusion::logical_expr::LogicalPlan;
+use datafusion::prelude::SessionContext;
+use std::ops::Deref;
+use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tonic::{Request, Response, Status};
@@ -58,6 +65,9 @@ use crate::scheduler_server::SchedulerServer;
impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
for SchedulerServer<T, U>
{
+ type ExecuteQueryPushStream =
+ Pin<Box<dyn Stream<Item = Result<GetJobStatusResult, Status>> + Send>>;
+
async fn poll_work(
&self,
request: Request<PollWorkParams>,
@@ -333,10 +343,11 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerGrpc
Ok(Response::new(RemoveSessionResult { success: true }))
}
- async fn execute_query(
+ async fn execute_query_push(
&self,
- request: Request<ExecuteQueryParams>,
- ) -> Result<Response<ExecuteQueryResult>, Status> {
+ request: tonic::Request<ExecuteQueryParams>,
+ ) -> std::result::Result<tonic::Response<Self::ExecuteQueryPushStream>,
tonic::Status>
+ {
let query_params = request.into_inner();
if let ExecuteQueryParams {
query: Some(query),
@@ -351,44 +362,99 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerGrpc
.and_then(|s| s.value.clone())
.unwrap_or_default();
- let job_id = self.state.task_manager.generate_job_id();
+ info!(
+ "execution query (PUSH) job received - session_id:
{session_id}, operation_id: {operation_id}, job_name: {job_name}"
+ );
+
+ let (session_id, session_ctx) = self
+ .create_context(&settings, session_id)
+ .await
+ .map_err(|e| {
+ Status::internal(format!("Failed to create SessionContext:
{e:?}"))
+ })?;
+
+ let plan = self.parse_plan(query, &session_ctx).await.map_err(|e| {
+ let msg = format!("Could not parse plan: {e}");
+ error!("{}", msg);
+
+ Status::invalid_argument(msg)
+ })?;
+
+ debug!(
+ "Decoded logical plan for execution:\n{}",
+ plan.display_indent()
+ );
+ log::trace!("setting job name: {job_name}");
+
+ let flight_proxy = self.flight_proxy_config();
+
+ let (subscriber, rx) = tokio::sync::mpsc::channel::<JobStatus>(16);
+ let stream = ReceiverStream::new(rx).map(move |status| {
+ Ok::<_, tonic::Status>(GetJobStatusResult {
+ status: Some(status),
+ flight_proxy: flight_proxy.clone(),
+ })
+ });
+
+ let job_id = self
+ .submit_job(&job_name, session_ctx, &plan, Some(subscriber))
+ .await
+ .map_err(|e| {
+ let msg =
+ format!("Failed to send JobQueued event for
{job_name}: {e:?}");
+
+ error!("{msg}");
+
+ Status::internal(msg)
+ })?;
info!(
- "execution query - session_id: {session_id}, operation_id:
{operation_id}, job_name: {job_name}, job_id: {job_id}"
+ "execution query (PUSH) job submitted - session_id:
{session_id}, operation_id: {operation_id}, job_name: {job_name}, job_id:
{job_id}"
);
- let (session_id, session_ctx) = {
- let session_config =
self.state.session_manager.produce_config();
- let session_config =
session_config.update_from_key_value_pair(&settings);
+ Ok(Response::new(Box::pin(stream)))
+ } else {
+ Err(Status::internal(
+ "Error processing request, invalid message",
+ ))
+ }
+ }
- let ctx = self
- .state
- .session_manager
- .create_or_update_session(&session_id, &session_config)
- .await
- .map_err(|e| {
- Status::internal(format!(
- "Failed to create SessionContext: {e:?}"
- ))
- })?;
+ async fn execute_query(
+ &self,
+ request: Request<ExecuteQueryParams>,
+ ) -> Result<Response<ExecuteQueryResult>, Status> {
+ let query_params = request.into_inner();
+ if let ExecuteQueryParams {
+ query: Some(query),
+ session_id,
+ operation_id,
+ settings,
+ } = query_params
+ {
+ let job_name = settings
+ .iter()
+ .find(|s| s.key == BALLISTA_JOB_NAME)
+ .and_then(|s| s.value.clone())
+ .unwrap_or_default();
- (session_id, ctx)
- };
+ info!(
+ "execution query job received - session_id: {session_id},
operation_id: {operation_id}, job_name: {job_name}"
+ );
+
+ let (session_id, session_ctx) = self
+ .create_context(&settings, session_id)
+ .await
+ .map_err(|e| {
+ Status::internal(format!("Failed to create SessionContext:
{e:?}"))
+ })?;
- let plan = match query {
- Query::LogicalPlan(message) => {
- match T::try_decode(message.as_slice()).and_then(|m| {
- m.try_into_logical_plan(
- session_ctx.task_ctx().deref(),
- self.state.codec.logical_extension_codec(),
- )
- }) {
- Ok(plan) => plan,
- Err(e) => {
- let msg =
- format!("Could not parse logical plan
protobuf: {e}");
- error!("{msg}");
- return Ok(Response::new(ExecuteQueryResult {
+ let plan = match self.parse_plan(query, &session_ctx).await {
+ Ok(plan) => plan,
+ Err(e) => {
+ let msg = format!("Could not parse plan: {e}");
+ error!("{msg}");
+ return Ok(Response::new(ExecuteQueryResult {
operation_id,
result:
Some(execute_query_result::Result::Failure(
ExecuteQueryFailureResult {
@@ -396,38 +462,6 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerGrpc
},
)),
}));
- }
- }
- }
- #[cfg(not(feature = "substrait"))]
- Query::SubstraitPlan(_) => {
- let msg = "Received query type \"Substrait\", enable
\"substrait\" feature to support Substrait plans.".to_string();
- error!("{msg}");
- return Ok(Response::new(ExecuteQueryResult {
- operation_id,
- result: Some(execute_query_result::Result::Failure(
- ExecuteQueryFailureResult {
- failure:
Some(execute_query_failure_result::Failure::PlanParsingFailure(msg)),
- }
- ))
- }));
- }
- #[cfg(feature = "substrait")]
- Query::SubstraitPlan(bytes) => {
- let plan = deserialize_bytes(bytes).await.map_err(|e| {
- let msg = format!("Could not parse substrait plan:
{e}");
- error!("{}", msg);
- Status::internal(msg)
- })?;
-
- let ctx = session_ctx.as_ref().clone();
- from_substrait_plan(&ctx.state(), &plan)
- .await
- .map_err(|e| {
- let msg = format!("Could not parse substrait plan:
{e}");
- error!("{}", msg);
- Status::internal(msg)
- })?
}
};
@@ -438,7 +472,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerGrpc
log::trace!("setting job name: {job_name}");
let job_id = self
- .submit_job(&job_name, session_ctx, &plan)
+ .submit_job(&job_name, session_ctx, &plan, None)
.await
.map_err(|e| {
let msg =
@@ -448,6 +482,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerGrpc
Status::internal(msg)
})?;
+ info!(
+ "execution query, job submitted - session_id: {session_id},
operation_id: {operation_id}, job_name: {job_name}"
+ );
+
Ok(Response::new(ExecuteQueryResult {
operation_id,
result: Some(execute_query_result::Result::Success(
@@ -466,15 +504,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerGrpc
let job_id = request.into_inner().job_id;
trace!("Received get_job_status request for job {}", job_id);
- let flight_proxy =
- self.state
- .config
- .advertise_flight_sql_endpoint
- .clone()
- .map(|s| match s {
- s if s.is_empty() => FlightProxy::Local(true),
- s => FlightProxy::External(s),
- });
+ let flight_proxy = self.flight_proxy_config();
match self.state.task_manager.get_job_status(&job_id).await {
Ok(status) => Ok(Response::new(GetJobStatusResult {
@@ -569,6 +599,67 @@ fn extract_connect_info<T>(request: &Request<T>) ->
Option<ConnectInfo<SocketAdd
.cloned()
}
+impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
SchedulerServer<T, U> {
+ async fn create_context(
+ &self,
+ settings: &[KeyValuePair],
+ session_id: String,
+ ) -> BResult<(String, Arc<SessionContext>)> {
+ let session_config = self.state.session_manager.produce_config();
+ let session_config =
session_config.update_from_key_value_pair(settings);
+
+ let ctx = self
+ .state
+ .session_manager
+ .create_or_update_session(&session_id, &session_config)
+ .await?;
+
+ Ok((session_id, ctx))
+ }
+
+ async fn parse_plan(
+ &self,
+ query: Query,
+ session_ctx: &SessionContext,
+ ) -> BResult<LogicalPlan> {
+ match query {
+ Query::LogicalPlan(message) => T::try_decode(message.as_slice())
+ .and_then(|m| {
+ m.try_into_logical_plan(
+ session_ctx.task_ctx().deref(),
+ self.state.codec.logical_extension_codec(),
+ )
+ })
+ .map_err(|e| e.into()),
+
+ #[cfg(not(feature = "substrait"))]
+ Query::SubstraitPlan(_) => {
+ Err(BallistaError::NotImplemented("Received query type
\"Substrait\", enable \"substrait\" feature to support Substrait
plans.".to_string()))
+ }
+ #[cfg(feature = "substrait")]
+ Query::SubstraitPlan(bytes) => {
+ let plan = deserialize_bytes(bytes).await.map_err(|e|
BallistaError::DataFusionError(e.into()))?;
+
+ let ctx = session_ctx.clone();
+ from_substrait_plan(&ctx.state(), &plan)
+ .await
+ .map_err(|e| e.into())
+ }
+ }
+ }
+
+ fn flight_proxy_config(&self) -> Option<FlightProxy> {
+ self.state
+ .config
+ .advertise_flight_sql_endpoint
+ .clone()
+ .map(|s| match s {
+ s if s.is_empty() => FlightProxy::Local(true),
+ s => FlightProxy::External(s),
+ })
+ }
+}
+
#[cfg(test)]
mod test {
use std::sync::Arc;
diff --git a/ballista/scheduler/src/scheduler_server/mod.rs
b/ballista/scheduler/src/scheduler_server/mod.rs
index 94b34a1dc..8a0fbec2b 100644
--- a/ballista/scheduler/src/scheduler_server/mod.rs
+++ b/ballista/scheduler/src/scheduler_server/mod.rs
@@ -18,6 +18,7 @@
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
+use ballista_core::JobStatusSubscriber;
use ballista_core::error::Result;
use ballista_core::event_loop::{EventLoop, EventSender};
use ballista_core::serde::BallistaCodec;
@@ -222,10 +223,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerServer<T
job_name: &str,
ctx: Arc<SessionContext>,
plan: &LogicalPlan,
+ subscriber: Option<JobStatusSubscriber>,
) -> Result<String> {
log::debug!("Received submit request for job {job_name}");
let job_id = self.state.task_manager.generate_job_id();
-
self.query_stage_event_loop
.get_sender()?
.post_event(QueryStageSchedulerEvent::JobQueued {
@@ -234,6 +235,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerServer<T
session_ctx: ctx,
plan: Box::new(plan.clone()),
queued_at: timestamp_millis(),
+ subscriber,
})
.await?;
@@ -410,6 +412,7 @@ mod test {
use std::sync::Arc;
use ballista_core::extension::SessionConfigExt;
+ use ballista_core::serde::protobuf::job_status::Status;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::functions_aggregate::sum::sum;
use datafusion::logical_expr::{LogicalPlan, col};
@@ -480,7 +483,7 @@ mod test {
// Submit job
scheduler
.state
- .submit_job(job_id, "", ctx, &plan, 0)
+ .submit_job(job_id, "", ctx, &plan, 0, None)
.await
.expect("submitting plan");
@@ -590,6 +593,64 @@ mod test {
Ok(())
}
+ // checks if job subscriber is getting same events
+ #[tokio::test]
+ async fn test_push_scheduling_with_subscriber() -> Result<()> {
+ let plan = test_plan();
+ // this test will fail when AQE scheduling is used.
+ // as AQE will fold plan due to empty scan
+ let metrics_collector = Arc::new(TestMetricsCollector::default());
+
+ let mut test = SchedulerTest::new(
+ SchedulerConfig::default()
+ .with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
+ metrics_collector.clone(),
+ 4,
+ 1,
+ None,
+ )
+ .await?;
+ let (tx, mut rx) = tokio::sync::mpsc::channel(16);
+
+ let (status, job_id) = test
+ .run_with_subscriber("", &plan, Some(tx))
+ .await
+ .expect("running plan");
+
+ match status.status {
+ Some(job_status::Status::Successful(SuccessfulJob {
+ partition_location,
+ ..
+ })) => {
+ assert_eq!(partition_location.len(), 4);
+ }
+ other => {
+ panic!("Expected success status but found {other:?}");
+ }
+ }
+
+ assert_submitted_event(&job_id, &metrics_collector);
+ assert_completed_event(&job_id, &metrics_collector);
+
+ let mut buffer = vec![];
+ rx.recv_many(&mut buffer, 16).await;
+ assert!(!buffer.is_empty());
+
+ let successful_job = buffer
+ .iter()
+ .find(|s| matches!(s.status, Some(Status::Successful(_))));
+
+ assert!(successful_job.is_some());
+
+ let failed_job = buffer
+ .iter()
+ .find(|s| matches!(s.status, Some(Status::Failed(_))));
+
+ assert!(failed_job.is_none());
+
+ Ok(())
+ }
+
// Simulate a task failure and ensure the job status is updated correctly
#[tokio::test]
async fn test_job_failure() -> Result<()> {
@@ -665,6 +726,102 @@ mod test {
Ok(())
}
+ // Simulate a task failure and ensure the job status is updated correctly
+ // it also checks if job subscriber is getting same events
+ #[tokio::test]
+ async fn test_job_failure_subscriber() -> Result<()> {
+ let plan = test_plan();
+
+ let runner = Arc::new(TaskRunnerFn::new(
+ |_executor_id: String, task: MultiTaskDefinition| {
+ let mut statuses = vec![];
+
+ for TaskId {
+ task_id,
+ partition_id,
+ ..
+ } in task.task_ids
+ {
+ let timestamp = timestamp_millis();
+ statuses.push(TaskStatus {
+ task_id,
+ job_id: task.job_id.clone(),
+ stage_id: task.stage_id,
+ stage_attempt_num: task.stage_attempt_num,
+ partition_id,
+ launch_time: timestamp,
+ start_exec_time: timestamp,
+ end_exec_time: timestamp,
+ metrics: vec![],
+ status: Some(task_status::Status::Failed(FailedTask {
+ error: "ERROR".to_string(),
+ retryable: false,
+ count_to_failures: false,
+ failed_reason: Some(
+ failed_task::FailedReason::ExecutionError(
+ ExecutionError {},
+ ),
+ ),
+ })),
+ });
+ }
+
+ statuses
+ },
+ ));
+
+ let metrics_collector = Arc::new(TestMetricsCollector::default());
+
+ let mut test = SchedulerTest::new(
+ SchedulerConfig::default()
+ .with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
+ metrics_collector.clone(),
+ 4,
+ 1,
+ Some(runner),
+ )
+ .await?;
+ let (tx, mut rx) = tokio::sync::mpsc::channel(16);
+ let (status, job_id) = test
+ .run_with_subscriber("", &plan, Some(tx))
+ .await
+ .expect("running plan");
+
+ assert!(
+ matches!(
+ status,
+ JobStatus {
+ status: Some(job_status::Status::Failed(_)),
+ ..
+ }
+ ),
+ "{}",
+ "Expected job status to be failed but it was {status:?}"
+ );
+
+ assert_submitted_event(&job_id, &metrics_collector);
+ assert_failed_event(&job_id, &metrics_collector);
+
+ let mut buffer = vec![];
+ rx.recv_many(&mut buffer, 16).await;
+
+ assert!(!buffer.is_empty());
+
+ let failed_job = buffer
+ .iter()
+ .find(|s| matches!(s.status, Some(Status::Failed(_))));
+
+ assert!(failed_job.is_some());
+
+ let successful_job = buffer
+ .iter()
+ .find(|s| matches!(s.status, Some(Status::Successful(_))));
+
+ assert!(successful_job.is_none());
+
+ Ok(())
+ }
+
// If the physical planning fails, the job should be marked as failed.
// Here we simulate a planning failure using ExplodingTableProvider to
test this.
#[tokio::test]
@@ -710,6 +867,68 @@ mod test {
Ok(())
}
+ // If the physical planning fails, the job should be marked as failed.
+ // Here we simulate a planning failure using ExplodingTableProvider to
test this.
+ // it also checks if job subscriber is getting same events
+ #[tokio::test]
+ async fn test_planning_failure_with_subscriber() -> Result<()> {
+ let metrics_collector = Arc::new(TestMetricsCollector::default());
+ let mut test = SchedulerTest::new(
+ SchedulerConfig::default()
+ .with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
+ metrics_collector.clone(),
+ 4,
+ 1,
+ None,
+ )
+ .await?;
+
+ let ctx = test.ctx().await?;
+
+ ctx.register_table("explode", Arc::new(ExplodingTableProvider))?;
+
+ let plan = ctx
+ .sql("SELECT * FROM explode")
+ .await?
+ .into_optimized_plan()?;
+ let (tx, mut rx) = tokio::sync::mpsc::channel(16);
+ // This should fail when we try and create the physical plan
+ let (status, job_id) = test.run_with_subscriber("", &plan,
Some(tx)).await?;
+
+ assert!(
+ matches!(
+ status,
+ JobStatus {
+ status: Some(job_status::Status::Failed(_)),
+ ..
+ }
+ ),
+ "{}",
+ "Expected job status to be failed but it was {status:?}"
+ );
+
+ assert_no_submitted_event(&job_id, &metrics_collector);
+ assert_failed_event(&job_id, &metrics_collector);
+
+ let mut buffer = vec![];
+ rx.recv_many(&mut buffer, 16).await;
+ assert!(!buffer.is_empty());
+
+ let failed_job = buffer
+ .iter()
+ .find(|s| matches!(s.status, Some(Status::Failed(_))));
+
+ assert!(failed_job.is_some());
+
+ let successful_job = buffer
+ .iter()
+ .find(|s| matches!(s.status, Some(Status::Successful(_))));
+
+ assert!(successful_job.is_none());
+
+ Ok(())
+ }
+
async fn test_scheduler(
scheduling_policy: TaskSchedulingPolicy,
) -> Result<SchedulerServer<LogicalPlanNode, PhysicalPlanNode>> {
diff --git a/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs
b/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs
index 97e7f8569..85b0924e2 100644
--- a/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs
+++ b/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs
@@ -19,10 +19,12 @@ use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
+use ballista_core::serde::protobuf::{FailedJob, JobStatus};
use log::{error, info, trace, warn};
use ballista_core::error::{BallistaError, Result};
use ballista_core::event_loop::{EventAction, EventSender};
+use tokio::sync::mpsc::error::TrySendError;
use crate::config::SchedulerConfig;
use crate::metrics::SchedulerMetricsCollector;
@@ -93,6 +95,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
session_ctx,
plan,
queued_at,
+ subscriber,
} => {
info!("Job {job_id} queued with name {job_name:?}");
@@ -108,10 +111,42 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan>
let state = self.state.clone();
tokio::spawn(async move {
let event = if let Err(e) = state
- .submit_job(&job_id, &job_name, session_ctx, &plan,
queued_at)
+ .submit_job(
+ &job_id,
+ &job_name,
+ session_ctx,
+ &plan,
+ queued_at,
+ subscriber.clone(),
+ )
.await
{
+ let error = e.to_string();
let fail_message = format!("Error planning job
{job_id}: {e:?}");
+
+ // this is a corner case, as most of job status
changes are handled in
+ // job state, after job is submitted to job state
+ if let Some(subscriber) = subscriber {
+ let timestamp = timestamp_millis();
+ let job_status = JobStatus {
+ job_id: job_id.clone(),
+ job_name,
+ status:
Some(ballista_core::serde::protobuf::job_status::Status::Failed(
+ FailedJob { error, queued_at, started_at:
timestamp, ended_at: timestamp }
+ ))
+ };
+
+ if matches!(
+ subscriber.try_send(job_status),
+ Err(TrySendError::Full(_))
+ ) {
+ error!(
+ "jobs notification subscriber for job {}
is blocked, can't deliver status update, job notification will be missed",
+ job_id
+ )
+ }
+ }
+
error!("{}", &fail_message);
QueryStageSchedulerEvent::JobPlanningFailed {
job_id,
diff --git a/ballista/scheduler/src/state/mod.rs
b/ballista/scheduler/src/state/mod.rs
index 49e473c58..82fa60c97 100644
--- a/ballista/scheduler/src/state/mod.rs
+++ b/ballista/scheduler/src/state/mod.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use ballista_core::JobStatusSubscriber;
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion::datasource::listing::{ListingTable, ListingTableUrl};
use datafusion::datasource::source_as_provider;
@@ -379,6 +380,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerState<T,
session_ctx: Arc<SessionContext>,
plan: &LogicalPlan,
queued_at: u64,
+ subscriber: Option<JobStatusSubscriber>,
) -> Result<()> {
let start = Instant::now();
let session_config = Arc::new(session_ctx.copied_config());
@@ -487,6 +489,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerState<T,
plan.data,
queued_at,
session_config,
+ subscriber,
)
.await?;
diff --git a/ballista/scheduler/src/state/task_manager.rs
b/ballista/scheduler/src/state/task_manager.rs
index ae3686901..b55806407 100644
--- a/ballista/scheduler/src/state/task_manager.rs
+++ b/ballista/scheduler/src/state/task_manager.rs
@@ -23,6 +23,7 @@ use crate::state::execution_graph::{
};
use crate::state::executor_manager::ExecutorManager;
+use ballista_core::JobStatusSubscriber;
use ballista_core::error::BallistaError;
use ballista_core::error::Result;
use ballista_core::extension::{SessionConfigExt, SessionConfigHelperExt};
@@ -156,6 +157,7 @@ impl JobInfoCache {
/// Creates a new `JobInfoCache` from an execution graph.
pub fn new(graph: ExecutionGraphBox) -> Self {
let status = graph.status().status.clone();
+
Self {
execution_graph: Arc::new(RwLock::new(graph)),
status,
@@ -266,6 +268,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> TaskManager<T, U>
/// Generate an ExecutionGraph for the job and save it to the persistent
state.
/// By default, this job will be curated by the scheduler which receives
it.
/// Then we will also save it to the active execution graph
+ #[allow(clippy::too_many_arguments)]
pub async fn submit_job(
&self,
job_id: &str,
@@ -274,6 +277,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> TaskManager<T, U>
plan: Arc<dyn ExecutionPlan>,
queued_at: u64,
session_config: Arc<SessionConfig>,
+ subscriber: Option<JobStatusSubscriber>,
) -> Result<()> {
let mut planner = DefaultDistributedPlanner::new();
@@ -307,7 +311,9 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> TaskManager<T, U>
info!("Submitting execution graph:\n\n{graph:?}");
- self.state.submit_job(job_id.to_string(), &graph).await?;
+ self.state
+ .submit_job(job_id.to_string(), &graph, subscriber)
+ .await?;
graph.revive();
self.active_job_cache
.insert(job_id.to_owned(), JobInfoCache::new(graph));
diff --git a/ballista/scheduler/src/test_utils.rs
b/ballista/scheduler/src/test_utils.rs
index 7ac6f83a9..1d4f3633f 100644
--- a/ballista/scheduler/src/test_utils.rs
+++ b/ballista/scheduler/src/test_utils.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use ballista_core::JobStatusSubscriber;
use ballista_core::error::{BallistaError, Result};
use ballista_core::extension::SessionConfigExt;
use datafusion::catalog::Session;
@@ -507,7 +508,7 @@ impl SchedulerTest {
.create_or_update_session("session_id", &self.session_config)
.await?;
- let job_id = self.scheduler.submit_job(job_name, ctx, plan).await?;
+ let job_id = self.scheduler.submit_job(job_name, ctx, plan,
None).await?;
Ok(job_id)
}
@@ -627,6 +628,15 @@ impl SchedulerTest {
&mut self,
job_name: &str,
plan: &LogicalPlan,
+ ) -> Result<(JobStatus, String)> {
+ self.run_with_subscriber(job_name, plan, None).await
+ }
+ /// Returns job status and job_id, with provided subscriber
+ pub async fn run_with_subscriber(
+ &mut self,
+ job_name: &str,
+ plan: &LogicalPlan,
+ subscriber: Option<JobStatusSubscriber>,
) -> Result<(JobStatus, String)> {
let ctx = self
.scheduler
@@ -635,7 +645,10 @@ impl SchedulerTest {
.create_or_update_session("session_id", &self.session_config)
.await?;
- let job_id = self.scheduler.submit_job(job_name, ctx, plan).await?;
+ let job_id = self
+ .scheduler
+ .submit_job(job_name, ctx, plan, subscriber)
+ .await?;
let mut receiver = self.status_receiver.take().unwrap();
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]