This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e4842ce Only decode plan in `LaunchMultiTaskParams` once (#743)
4e4842ce is described below

commit 4e4842ce5221b8ce6ce39b82bb1346e337129b0d
Author: Daniël Heres <[email protected]>
AuthorDate: Thu Apr 13 14:02:35 2023 +0200

    Only decode plan in `LaunchMultiTaskParams` once (#743)
    
    * Only decode plan once
    
    * WIP
    
    * WIP
    
    * Clippy
    
    * Refactor
    
    * Refactor
    
    * Refactor
    
    * Refactor
    
    * Fmt
    
    * clippy
    
    * clippy
    
    * Cleanup
    
    * Reuse querystage exec
    
    * Fmt
    
    * Cleanup
    
    ---------
    
    Co-authored-by: Daniël Heres <[email protected]>
---
 ballista/core/src/serde/scheduler/from_proto.rs |  72 +++++----
 ballista/executor/src/execution_engine.rs       |   7 +
 ballista/executor/src/executor_server.rs        | 201 +++++++++++++++++-------
 3 files changed, 187 insertions(+), 93 deletions(-)

diff --git a/ballista/core/src/serde/scheduler/from_proto.rs 
b/ballista/core/src/serde/scheduler/from_proto.rs
index 6c008abb..ec39ba3a 100644
--- a/ballista/core/src/serde/scheduler/from_proto.rs
+++ b/ballista/core/src/serde/scheduler/from_proto.rs
@@ -269,35 +269,38 @@ impl Into<ExecutorData> for protobuf::ExecutorData {
     }
 }
 
-impl TryInto<TaskDefinition> for protobuf::TaskDefinition {
+impl TryInto<(TaskDefinition, Vec<u8>)> for protobuf::TaskDefinition {
     type Error = BallistaError;
 
-    fn try_into(self) -> Result<TaskDefinition, Self::Error> {
+    fn try_into(self) -> Result<(TaskDefinition, Vec<u8>), Self::Error> {
         let mut props = HashMap::new();
         for kv_pair in self.props {
             props.insert(kv_pair.key, kv_pair.value);
         }
 
-        Ok(TaskDefinition {
-            task_id: self.task_id as usize,
-            task_attempt_num: self.task_attempt_num as usize,
-            job_id: self.job_id,
-            stage_id: self.stage_id as usize,
-            stage_attempt_num: self.stage_attempt_num as usize,
-            partition_id: self.partition_id as usize,
-            plan: self.plan,
-            output_partitioning: self.output_partitioning,
-            session_id: self.session_id,
-            launch_time: self.launch_time,
-            props,
-        })
+        Ok((
+            TaskDefinition {
+                task_id: self.task_id as usize,
+                task_attempt_num: self.task_attempt_num as usize,
+                job_id: self.job_id,
+                stage_id: self.stage_id as usize,
+                stage_attempt_num: self.stage_attempt_num as usize,
+                partition_id: self.partition_id as usize,
+                plan: vec![],
+                output_partitioning: self.output_partitioning,
+                session_id: self.session_id,
+                launch_time: self.launch_time,
+                props,
+            },
+            self.plan,
+        ))
     }
 }
 
-impl TryInto<Vec<TaskDefinition>> for protobuf::MultiTaskDefinition {
+impl TryInto<(Vec<TaskDefinition>, Vec<u8>)> for protobuf::MultiTaskDefinition 
{
     type Error = BallistaError;
 
-    fn try_into(self) -> Result<Vec<TaskDefinition>, Self::Error> {
+    fn try_into(self) -> Result<(Vec<TaskDefinition>, Vec<u8>), Self::Error> {
         let mut props = HashMap::new();
         for kv_pair in self.props {
             props.insert(kv_pair.key, kv_pair.value);
@@ -312,21 +315,24 @@ impl TryInto<Vec<TaskDefinition>> for 
protobuf::MultiTaskDefinition {
         let launch_time = self.launch_time;
         let task_ids = self.task_ids;
 
-        Ok(task_ids
-            .iter()
-            .map(|task_id| TaskDefinition {
-                task_id: task_id.task_id as usize,
-                task_attempt_num: task_id.task_attempt_num as usize,
-                job_id: job_id.clone(),
-                stage_id,
-                stage_attempt_num,
-                partition_id: task_id.partition_id as usize,
-                plan: plan.clone(),
-                output_partitioning: output_partitioning.clone(),
-                session_id: session_id.clone(),
-                launch_time,
-                props: props.clone(),
-            })
-            .collect())
+        Ok((
+            task_ids
+                .iter()
+                .map(|task_id| TaskDefinition {
+                    task_id: task_id.task_id as usize,
+                    task_attempt_num: task_id.task_attempt_num as usize,
+                    job_id: job_id.clone(),
+                    stage_id,
+                    stage_attempt_num,
+                    partition_id: task_id.partition_id as usize,
+                    plan: vec![],
+                    output_partitioning: output_partitioning.clone(),
+                    session_id: session_id.clone(),
+                    launch_time,
+                    props: props.clone(),
+                })
+                .collect(),
+            plan,
+        ))
     }
 }
diff --git a/ballista/executor/src/execution_engine.rs 
b/ballista/executor/src/execution_engine.rs
index d62176a9..c75d4743 100644
--- a/ballista/executor/src/execution_engine.rs
+++ b/ballista/executor/src/execution_engine.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arrow::datatypes::SchemaRef;
 use async_trait::async_trait;
 use ballista_core::execution_plans::ShuffleWriterExec;
 use ballista_core::serde::protobuf::ShuffleWritePartition;
@@ -51,6 +52,8 @@ pub trait QueryStageExecutor: Sync + Send + Debug {
     ) -> Result<Vec<ShuffleWritePartition>>;
 
     fn collect_plan_metrics(&self) -> Vec<MetricsSet>;
+
+    fn schema(&self) -> SchemaRef;
 }
 
 pub struct DefaultExecutionEngine {}
@@ -108,6 +111,10 @@ impl QueryStageExecutor for DefaultQueryStageExec {
             .await
     }
 
+    fn schema(&self) -> SchemaRef {
+        self.shuffle_writer.schema()
+    }
+
     fn collect_plan_metrics(&self) -> Vec<MetricsSet> {
         utils::collect_plan_metrics(self.shuffle_writer.children()[0].as_ref())
     }
diff --git a/ballista/executor/src/executor_server.rs 
b/ballista/executor/src/executor_server.rs
index 26dec059..ff8e24eb 100644
--- a/ballista/executor/src/executor_server.rs
+++ b/ballista/executor/src/executor_server.rs
@@ -47,7 +47,6 @@ use ballista_core::serde::BallistaCodec;
 use ballista_core::utils::{create_grpc_client_connection, create_grpc_server};
 use dashmap::DashMap;
 use datafusion::execution::context::TaskContext;
-use datafusion::physical_plan::ExecutionPlan;
 use datafusion_proto::{
     logical_plan::AsLogicalPlan,
     physical_plan::{from_proto::parse_protobuf_hash_partitioning, 
AsExecutionPlan},
@@ -56,6 +55,7 @@ use tokio::sync::mpsc::error::TryRecvError;
 use tokio::task::JoinHandle;
 
 use crate::cpu_bound_executor::DedicatedExecutor;
+use crate::execution_engine::QueryStageExecutor;
 use crate::executor::Executor;
 use crate::shutdown::ShutdownNotifier;
 use crate::{as_task_status, TaskExecutionTimes};
@@ -67,7 +67,8 @@ type SchedulerClients = Arc<DashMap<String, 
SchedulerGrpcClient<Channel>>>;
 #[derive(Debug)]
 struct CuratorTaskDefinition {
     scheduler_id: String,
-    task: TaskDefinition,
+    plan: Vec<u8>,
+    tasks: Vec<TaskDefinition>,
 }
 
 /// Wrap TaskStatus with its curator scheduler id for task update to its 
specific curator scheduler later
@@ -296,17 +297,67 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorServer<T,
         }
     }
 
+    async fn decode_task(
+        &self,
+        curator_task: TaskDefinition,
+        plan: &[u8],
+    ) -> Result<Arc<dyn QueryStageExecutor>, BallistaError> {
+        let task = curator_task;
+        let task_identity = task_identity(&task);
+        let task_props = task.props;
+        let mut config = ConfigOptions::new();
+        for (k, v) in task_props {
+            config.set(&k, &v)?;
+        }
+        let session_config = SessionConfig::from(config);
+
+        let mut task_scalar_functions = HashMap::new();
+        let mut task_aggregate_functions = HashMap::new();
+        for scalar_func in self.executor.scalar_functions.clone() {
+            task_scalar_functions.insert(scalar_func.0, scalar_func.1);
+        }
+        for agg_func in self.executor.aggregate_functions.clone() {
+            task_aggregate_functions.insert(agg_func.0, agg_func.1);
+        }
+
+        let task_context = Arc::new(TaskContext::new(
+            Some(task_identity),
+            task.session_id.clone(),
+            session_config,
+            task_scalar_functions,
+            task_aggregate_functions,
+            self.executor.runtime.clone(),
+        ));
+
+        let plan = U::try_decode(plan).and_then(|proto| {
+            proto.try_into_physical_plan(
+                task_context.deref(),
+                &self.executor.runtime,
+                self.codec.physical_extension_codec(),
+            )
+        })?;
+
+        Ok(self.executor.execution_engine.create_query_stage_exec(
+            task.job_id,
+            task.stage_id,
+            plan,
+            &self.executor.work_dir,
+        )?)
+    }
+
     async fn run_task(
         &self,
-        task_identity: String,
-        curator_task: CuratorTaskDefinition,
+        task_identity: &str,
+        scheduler_id: String,
+        curator_task: TaskDefinition,
+        query_stage_exec: Arc<dyn QueryStageExecutor>,
     ) -> Result<(), BallistaError> {
         let start_exec_time = SystemTime::now()
             .duration_since(UNIX_EPOCH)
             .unwrap()
             .as_millis() as u64;
         info!("Start to run task {}", task_identity);
-        let task = curator_task.task;
+        let task = curator_task;
         let task_props = task.props;
         let mut config = ConfigOptions::new();
         for (k, v) in task_props {
@@ -327,7 +378,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorServer<T,
         let session_id = task.session_id;
         let runtime = self.executor.runtime.clone();
         let task_context = Arc::new(TaskContext::new(
-            Some(task_identity.clone()),
+            Some(task_identity.to_string()),
             session_id,
             session_config,
             task_scalar_functions,
@@ -335,21 +386,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorServer<T,
             runtime.clone(),
         ));
 
-        let encoded_plan = &task.plan.as_slice();
-
-        let plan: Arc<dyn ExecutionPlan> =
-            U::try_decode(encoded_plan).and_then(|proto| {
-                proto.try_into_physical_plan(
-                    task_context.deref(),
-                    runtime.deref(),
-                    self.codec.physical_extension_codec(),
-                )
-            })?;
-
         let shuffle_output_partitioning = parse_protobuf_hash_partitioning(
             task.output_partitioning.as_ref(),
             task_context.as_ref(),
-            plan.schema().as_ref(),
+            query_stage_exec.schema().as_ref(),
         )?;
 
         let task_id = task.task_id;
@@ -357,12 +397,6 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorServer<T,
         let stage_id = task.stage_id;
         let stage_attempt_num = task.stage_attempt_num;
         let partition_id = task.partition_id;
-        let query_stage_exec = 
self.executor.execution_engine.create_query_stage_exec(
-            job_id.clone(),
-            stage_id,
-            plan,
-            &self.executor.work_dir,
-        )?;
 
         let part = PartitionId {
             job_id: job_id.clone(),
@@ -412,7 +446,6 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorServer<T,
             task_execution_times,
         );
 
-        let scheduler_id = curator_task.scheduler_id;
         let task_status_sender = self.executor_env.tx_task_status.clone();
         task_status_sender
             .send(CuratorTaskStatus {
@@ -474,6 +507,18 @@ struct TaskRunnerPool<T: 'static + AsLogicalPlan, U: 
'static + AsExecutionPlan>
     executor_server: Arc<ExecutorServer<T, U>>,
 }
 
+fn task_identity(task: &TaskDefinition) -> String {
+    format!(
+        "TID {} {}/{}.{}/{}.{}",
+        &task.task_id,
+        &task.job_id,
+        &task.stage_id,
+        &task.stage_attempt_num,
+        &task.partition_id,
+        &task.task_attempt_num,
+    )
+}
+
 impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> 
TaskRunnerPool<T, U> {
     fn new(executor_server: Arc<ExecutorServer<T, U>>) -> Self {
         Self { executor_server }
@@ -595,30 +640,64 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskRunnerPool<T,
                         return;
                     }
                 };
-                if let Some(curator_task) = maybe_task {
-                    let task_identity = format!(
-                        "TID {} {}/{}.{}/{}.{}",
-                        &curator_task.task.task_id,
-                        &curator_task.task.job_id,
-                        &curator_task.task.stage_id,
-                        &curator_task.task.stage_attempt_num,
-                        &curator_task.task.partition_id,
-                        &curator_task.task.task_attempt_num,
-                    );
-                    info!("Received task {:?}", &task_identity);
-
+                if let Some(task) = maybe_task {
                     let server = executor_server.clone();
-                    dedicated_executor.spawn(async move {
-                        server
-                            .run_task(task_identity.clone(), curator_task)
-                            .await
-                            .unwrap_or_else(|e| {
-                                error!(
-                                    "Fail to run the task {:?} due to {:?}",
-                                    task_identity, e
-                                );
-                            });
+                    let plan = task.plan;
+                    let curator_task = task.tasks[0].clone();
+                    let out: tokio::sync::oneshot::Receiver<
+                        Result<Arc<dyn QueryStageExecutor>, BallistaError>,
+                    > = dedicated_executor.spawn(async move {
+                        server.decode_task(curator_task, &plan).await
                     });
+
+                    let plan = out.await;
+
+                    let plan = match plan {
+                        Ok(Ok(plan)) => plan,
+                        Ok(Err(e)) => {
+                            error!(
+                                "Failed to decode the plan of task {:?} due to 
{:?}",
+                                task_identity(&task.tasks[0]),
+                                e
+                            );
+                            return;
+                        }
+                        Err(e) => {
+                            error!(
+                                "Failed to receive error plan of task {:?} due 
to {:?}",
+                                task_identity(&task.tasks[0]),
+                                e
+                            );
+                            return;
+                        }
+                    };
+                    let scheduler_id = task.scheduler_id.clone();
+
+                    for curator_task in task.tasks {
+                        let plan = plan.clone();
+                        let scheduler_id = scheduler_id.clone();
+
+                        let task_identity = task_identity(&curator_task);
+                        info!("Received task {:?}", &task_identity);
+
+                        let server = executor_server.clone();
+                        dedicated_executor.spawn(async move {
+                            server
+                                .run_task(
+                                    &task_identity,
+                                    scheduler_id,
+                                    curator_task,
+                                    plan,
+                                )
+                                .await
+                                .unwrap_or_else(|e| {
+                                    error!(
+                                        "Fail to run the task {:?} due to 
{:?}",
+                                        task_identity, e
+                                    );
+                                });
+                        });
+                    }
                 } else {
                     info!("Channel is closed and will exit the task receive 
loop");
                     drop(task_runner_complete);
@@ -643,12 +722,15 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorGrpc
         } = request.into_inner();
         let task_sender = self.executor_env.tx_task.clone();
         for task in tasks {
+            let (task_def, plan) = task
+                .try_into()
+                .map_err(|e| Status::invalid_argument(format!("{e}")))?;
+
             task_sender
                 .send(CuratorTaskDefinition {
                     scheduler_id: scheduler_id.clone(),
-                    task: task
-                        .try_into()
-                        .map_err(|e| 
Status::invalid_argument(format!("{e}")))?,
+                    plan,
+                    tasks: vec![task_def],
                 })
                 .await
                 .unwrap();
@@ -668,18 +750,17 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorGrpc
         } = request.into_inner();
         let task_sender = self.executor_env.tx_task.clone();
         for multi_task in multi_tasks {
-            let multi_task: Vec<TaskDefinition> = multi_task
+            let (multi_task, plan): (Vec<TaskDefinition>, Vec<u8>) = multi_task
                 .try_into()
                 .map_err(|e| Status::invalid_argument(format!("{e}")))?;
-            for task in multi_task {
-                task_sender
-                    .send(CuratorTaskDefinition {
-                        scheduler_id: scheduler_id.clone(),
-                        task,
-                    })
-                    .await
-                    .unwrap();
-            }
+            task_sender
+                .send(CuratorTaskDefinition {
+                    scheduler_id: scheduler_id.clone(),
+                    plan,
+                    tasks: multi_task,
+                })
+                .await
+                .unwrap();
         }
         Ok(Response::new(LaunchMultiTaskResult { success: true }))
     }

Reply via email to