This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f223db2 Upgrade to Datafusion 41 (#1062)
2f223db2 is described below

commit 2f223db21557c15080bf865ac692d276b8f0b770
Author: Baris Palaska <[email protected]>
AuthorDate: Fri Sep 27 19:16:15 2024 +0100

    Upgrade to Datafusion 41 (#1062)
    
    * upgrade dependencies, some fixes, still wip
    
    compiles
    
    add license
    
    Update datafusion protobuf definitions (#1057)
    
    * update datafusion proto defs
    
    * allow optionals in proto3
    
    update docker environment for higher protoc version
    
    * runs
    
    * test e2e, fix python
    
    * rm unnecessary dependency
    
    * create BallistaLogicalExtensionCodec that can decode/encode file formats, 
fix some tests
    
    * fix tests
    
    * clippy, tomlfmt
    
    * fix grpc connect info extract
    
    * extract into method, remove unnecessary log
    
    * datafusion to 41, adjust other deps
---
 Cargo.toml                                         |  26 +-
 ballista-cli/Cargo.toml                            |   4 +-
 ballista-cli/src/main.rs                           |  28 +-
 ballista/client/README.md                          |   3 +-
 ballista/client/src/context.rs                     |  32 +--
 ballista/core/Cargo.toml                           |   6 +-
 ballista/core/src/config.rs                        |  14 +-
 .../core/src/execution_plans/distributed_query.rs  |   4 +
 .../core/src/execution_plans/shuffle_reader.rs     |   4 +
 .../core/src/execution_plans/shuffle_writer.rs     |   4 +
 .../core/src/execution_plans/unresolved_shuffle.rs |   4 +
 ballista/core/src/serde/mod.rs                     | 104 +++++++-
 ballista/core/src/serde/scheduler/mod.rs           |   5 +
 ballista/core/src/utils.rs                         |  27 +-
 ballista/executor/Cargo.toml                       |   1 -
 ballista/executor/src/collect.rs                   |   4 +
 ballista/executor/src/executor.rs                  |  22 +-
 ballista/scheduler/Cargo.toml                      |  17 +-
 ballista/scheduler/src/api/handlers.rs             | 297 ++++++++++++---------
 ballista/scheduler/src/api/mod.rs                  | 153 +++--------
 ballista/scheduler/src/bin/main.rs                 |  13 +-
 ballista/scheduler/src/cluster/mod.rs              |  26 +-
 ballista/scheduler/src/config.rs                   |   6 +-
 ballista/scheduler/src/planner.rs                  |   2 +
 ballista/scheduler/src/scheduler_process.rs        |  85 ++----
 ballista/scheduler/src/scheduler_server/grpc.rs    |  15 +-
 ballista/scheduler/src/scheduler_server/mod.rs     |   6 +-
 ballista/scheduler/src/state/execution_graph.rs    | 140 ++++------
 .../scheduler/src/state/execution_graph_dot.rs     |  17 +-
 ballista/scheduler/src/test_utils.rs               |  73 ++++-
 examples/Cargo.toml                                |   4 +-
 python/.cargo/config.toml                          |  11 +
 python/Cargo.toml                                  |  11 +-
 python/src/context.rs                              |  25 +-
 python/src/lib.rs                                  |   3 +-
 35 files changed, 653 insertions(+), 543 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 3c451885..ea1c8321 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,19 +21,23 @@ members = ["ballista-cli", "ballista/cache", 
"ballista/client", "ballista/core",
 resolver = "2"
 
 [workspace.dependencies]
-arrow = { version = "52.0.0", features = ["ipc_compression"] }
-arrow-flight = { version = "52.0.0", features = ["flight-sql-experimental"] }
-arrow-schema = { version = "52.0.0", default-features = false }
+arrow = { version = "52.2.0", features = ["ipc_compression"] }
+arrow-flight = { version = "52.2.0", features = ["flight-sql-experimental"] }
+arrow-schema = { version = "52.2.0", default-features = false }
+clap = { version = "3", features = ["derive", "cargo"] }
 configure_me = { version = "0.4.0" }
 configure_me_codegen = { version = "0.4.4" }
-datafusion = "39.0.0"
-datafusion-cli = "39.0.0"
-datafusion-proto = "39.0.0"
-datafusion-proto-common = "39.0.0"
-object_store = "0.10.1"
-sqlparser = "0.47.0"
-tonic = { version = "0.11" }
-tonic-build = { version = "0.11", default-features = false, features = [
+# pinned to datafusion v41 for now; when upgrading, skip v42 and bump directly 
to v43 to avoid the serde bug in v42 (https://github.com/apache/datafusion/pull/12626)
+datafusion = "41.0.0"
+datafusion-cli = "41.0.0"
+datafusion-proto = "41.0.0"
+datafusion-proto-common = "41.0.0"
+object_store = "0.10.2"
+prost = "0.12.0"
+prost-types = "0.12.0"
+sqlparser = "0.49.0"
+tonic = { version = "0.11.0" }
+tonic-build = { version = "0.11.0", default-features = false, features = [
     "transport",
     "prost"
 ] }
diff --git a/ballista-cli/Cargo.toml b/ballista-cli/Cargo.toml
index e07ad279..dc8ff7cb 100644
--- a/ballista-cli/Cargo.toml
+++ b/ballista-cli/Cargo.toml
@@ -30,14 +30,14 @@ readme = "README.md"
 
 [dependencies]
 ballista = { path = "../ballista/client", version = "0.12.0", features = 
["standalone"] }
-clap = { version = "3", features = ["derive", "cargo"] }
+clap = { workspace = true }
 datafusion = { workspace = true }
 datafusion-cli = { workspace = true }
 dirs = "5.0.1"
 env_logger = "0.10"
 mimalloc = { version = "0.1", default-features = false }
 num_cpus = "1.13.0"
-rustyline = "11.0"
+rustyline = "11.0.0"
 tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", 
"sync", "parking_lot"] }
 
 [features]
diff --git a/ballista-cli/src/main.rs b/ballista-cli/src/main.rs
index a8055c87..6aeecd6c 100644
--- a/ballista-cli/src/main.rs
+++ b/ballista-cli/src/main.rs
@@ -36,7 +36,7 @@ struct Args {
         short = 'p',
         long,
         help = "Path to your data, default to current directory",
-        validator(is_valid_data_dir)
+        value_parser(parse_valid_data_dir)
     )]
     data_path: Option<String>,
 
@@ -44,14 +44,14 @@ struct Args {
         short = 'c',
         long,
         help = "The batch size of each query, or use Ballista default",
-        validator(is_valid_batch_size)
+        value_parser(parse_batch_size)
     )]
     batch_size: Option<usize>,
 
     #[clap(
         long,
         help = "The max concurrent tasks, only for Ballista local mode. 
Default: all available cores",
-        validator(is_valid_concurrent_tasks_size)
+        value_parser(parse_valid_concurrent_tasks_size)
     )]
     concurrent_tasks: Option<usize>,
 
@@ -60,7 +60,7 @@ struct Args {
         long,
         multiple_values = true,
         help = "Execute commands from file(s), then exit",
-        validator(is_valid_file)
+        value_parser(parse_valid_file)
     )]
     file: Vec<String>,
 
@@ -69,12 +69,12 @@ struct Args {
         long,
         multiple_values = true,
         help = "Run the provided files on startup instead of ~/.ballistarc",
-        validator(is_valid_file),
+        value_parser(parse_valid_file),
         conflicts_with = "file"
     )]
     rc: Option<Vec<String>>,
 
-    #[clap(long, arg_enum, default_value_t = PrintFormat::Table)]
+    #[clap(long, value_enum, default_value_t = PrintFormat::Table)]
     format: PrintFormat,
 
     #[clap(long, help = "Ballista scheduler host")]
@@ -168,32 +168,32 @@ pub async fn main() -> Result<()> {
     Ok(())
 }
 
-fn is_valid_file(dir: &str) -> std::result::Result<(), String> {
+fn parse_valid_file(dir: &str) -> std::result::Result<String, String> {
     if Path::new(dir).is_file() {
-        Ok(())
+        Ok(dir.to_string())
     } else {
         Err(format!("Invalid file '{dir}'"))
     }
 }
 
-fn is_valid_data_dir(dir: &str) -> std::result::Result<(), String> {
+fn parse_valid_data_dir(dir: &str) -> std::result::Result<String, String> {
     if Path::new(dir).is_dir() {
-        Ok(())
+        Ok(dir.to_string())
     } else {
         Err(format!("Invalid data directory '{dir}'"))
     }
 }
 
-fn is_valid_batch_size(size: &str) -> std::result::Result<(), String> {
+fn parse_batch_size(size: &str) -> std::result::Result<usize, String> {
     match size.parse::<usize>() {
-        Ok(size) if size > 0 => Ok(()),
+        Ok(size) if size > 0 => Ok(size),
         _ => Err(format!("Invalid batch size '{size}'")),
     }
 }
 
-fn is_valid_concurrent_tasks_size(size: &str) -> std::result::Result<(), 
String> {
+fn parse_valid_concurrent_tasks_size(size: &str) -> std::result::Result<usize, 
String> {
     match size.parse::<usize>() {
-        Ok(size) if size > 0 => Ok(()),
+        Ok(size) if size > 0 => Ok(size),
         _ => Err(format!("Invalid concurrent_tasks size '{size}'")),
     }
 }
diff --git a/ballista/client/README.md b/ballista/client/README.md
index 19dc1439..ac65bc98 100644
--- a/ballista/client/README.md
+++ b/ballista/client/README.md
@@ -92,7 +92,8 @@ data set. Download the file and add it to the `testdata` 
folder before running t
 
 ```rust,no_run
 use ballista::prelude::*;
-use datafusion::prelude::{col, min, max, avg, sum, ParquetReadOptions};
+use datafusion::prelude::{col, ParquetReadOptions};
+use datafusion::functions_aggregate::{min_max::min, min_max::max, sum::sum, 
average::avg};
 
 #[tokio::main]
 async fn main() -> Result<()> {
diff --git a/ballista/client/src/context.rs b/ballista/client/src/context.rs
index de22b777..269afc64 100644
--- a/ballista/client/src/context.rs
+++ b/ballista/client/src/context.rs
@@ -19,6 +19,7 @@
 
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::execution::context::DataFilePaths;
+use datafusion::sql::TableReference;
 use log::info;
 use parking_lot::Mutex;
 use sqlparser::ast::Statement;
@@ -33,7 +34,6 @@ use ballista_core::utils::{
 };
 use datafusion_proto::protobuf::LogicalPlanNode;
 
-use datafusion::catalog::TableReference;
 use datafusion::dataframe::DataFrame;
 use datafusion::datasource::{source_as_provider, TableProvider};
 use datafusion::error::{DataFusionError, Result};
@@ -791,7 +791,7 @@ mod standalone_tests {
         let res = df.collect().await.unwrap();
         let expected = vec![
             "+--------------+",
-            "| MIN(test.id) |",
+            "| min(test.id) |",
             "+--------------+",
             "| 0            |",
             "+--------------+",
@@ -802,7 +802,7 @@ mod standalone_tests {
         let res = df.collect().await.unwrap();
         let expected = vec![
             "+--------------+",
-            "| MAX(test.id) |",
+            "| max(test.id) |",
             "+--------------+",
             "| 7            |",
             "+--------------+",
@@ -818,7 +818,7 @@ mod standalone_tests {
         let res = df.collect().await.unwrap();
         let expected = vec![
             "+--------------+",
-            "| SUM(test.id) |",
+            "| sum(test.id) |",
             "+--------------+",
             "| 28           |",
             "+--------------+",
@@ -833,7 +833,7 @@ mod standalone_tests {
         let res = df.collect().await.unwrap();
         let expected = vec![
             "+--------------+",
-            "| AVG(test.id) |",
+            "| avg(test.id) |",
             "+--------------+",
             "| 3.5          |",
             "+--------------+",
@@ -849,7 +849,7 @@ mod standalone_tests {
         let res = df.collect().await.unwrap();
         let expected = vec![
             "+----------------+",
-            "| COUNT(test.id) |",
+            "| count(test.id) |",
             "+----------------+",
             "| 8              |",
             "+----------------+",
@@ -867,7 +867,7 @@ mod standalone_tests {
         let res = df.collect().await.unwrap();
         let expected = vec![
             "+--------------------------+",
-            "| APPROX_DISTINCT(test.id) |",
+            "| approx_distinct(test.id) |",
             "+--------------------------+",
             "| 8                        |",
             "+--------------------------+",
@@ -885,7 +885,7 @@ mod standalone_tests {
         let res = df.collect().await.unwrap();
         let expected = vec![
             "+--------------------------+",
-            "| ARRAY_AGG(test.id)       |",
+            "| array_agg(test.id)       |",
             "+--------------------------+",
             "| [4, 5, 6, 7, 2, 3, 0, 1] |",
             "+--------------------------+",
@@ -914,7 +914,7 @@ mod standalone_tests {
         let res = df.collect().await.unwrap();
         let expected = vec![
             "+-------------------+",
-            "| VAR_POP(test.id)  |",
+            "| var_pop(test.id)  |",
             "+-------------------+",
             "| 5.250000000000001 |",
             "+-------------------+",
@@ -946,7 +946,7 @@ mod standalone_tests {
         let res = df.collect().await.unwrap();
         let expected = vec![
             "+--------------------+",
-            "| STDDEV(test.id)    |",
+            "| stddev(test.id)    |",
             "+--------------------+",
             "| 2.4494897427831783 |",
             "+--------------------+",
@@ -960,7 +960,7 @@ mod standalone_tests {
         let res = df.collect().await.unwrap();
         let expected = vec![
             "+--------------------+",
-            "| STDDEV(test.id)    |",
+            "| stddev(test.id)    |",
             "+--------------------+",
             "| 2.4494897427831783 |",
             "+--------------------+",
@@ -996,25 +996,27 @@ mod standalone_tests {
         let res = df.collect().await.unwrap();
         let expected = vec![
             "+--------------------------------+",
-            "| CORR(test.id,test.tinyint_col) |",
+            "| corr(test.id,test.tinyint_col) |",
             "+--------------------------------+",
             "| 0.21821789023599245            |",
             "+--------------------------------+",
         ];
         assert_result_eq(expected, &res);
     }
+    // enable when upgrading Datafusion to > 42
+    #[ignore]
     #[tokio::test]
     async fn test_aggregate_approx_percentile() {
         let context = create_test_context().await;
 
         let df = context
-            .sql("select approx_percentile_cont_with_weight(\"id\", 2, 0.5) 
from test")
+            .sql("select approx_percentile_cont_with_weight(id, 2, 0.5) from 
test")
             .await
             .unwrap();
         let res = df.collect().await.unwrap();
         let expected = vec![
             
"+-------------------------------------------------------------------+",
-            "| 
APPROX_PERCENTILE_CONT_WITH_WEIGHT(test.id,Int64(2),Float64(0.5)) |",
+            "| 
approx_percentile_cont_with_weight(test.id,Int64(2),Float64(0.5)) |",
             
"+-------------------------------------------------------------------+",
             "| 1                                                               
  |",
             
"+-------------------------------------------------------------------+",
@@ -1028,7 +1030,7 @@ mod standalone_tests {
         let res = df.collect().await.unwrap();
         let expected = vec![
             "+------------------------------------------------------+",
-            "| APPROX_PERCENTILE_CONT(test.double_col,Float64(0.5)) |",
+            "| approx_percentile_cont(test.double_col,Float64(0.5)) |",
             "+------------------------------------------------------+",
             "| 7.574999999999999                                    |",
             "+------------------------------------------------------+",
diff --git a/ballista/core/Cargo.toml b/ballista/core/Cargo.toml
index fccdd0ec..8a01f56f 100644
--- a/ballista/core/Cargo.toml
+++ b/ballista/core/Cargo.toml
@@ -51,7 +51,7 @@ async-trait = "0.1.41"
 ballista-cache = { path = "../cache", version = "0.12.0" }
 bytes = "1.0"
 chrono = { version = "0.4", default-features = false }
-clap = { version = "3", features = ["derive", "cargo"] }
+clap = { workspace = true }
 datafusion = { workspace = true }
 datafusion-objectstore-hdfs = { version = "0.1.4", default-features = false, 
optional = true }
 datafusion-proto = { workspace = true }
@@ -68,8 +68,8 @@ once_cell = "1.9.0"
 
 parking_lot = "0.12"
 parse_arg = "0.1.3"
-prost = "0.12"
-prost-types = "0.12"
+prost = { workspace = true }
+prost-types = { workspace = true }
 rand = "0.8"
 serde = { version = "1", features = ["derive"] }
 sqlparser = { workspace = true }
diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs
index 03c8f6b9..46424ecf 100644
--- a/ballista/core/src/config.rs
+++ b/ballista/core/src/config.rs
@@ -18,7 +18,7 @@
 
 //! Ballista configuration
 
-use clap::ArgEnum;
+use clap::ValueEnum;
 use core::fmt;
 use std::collections::HashMap;
 use std::result;
@@ -307,7 +307,7 @@ impl BallistaConfig {
 
 // an enum used to configure the scheduler policy
 // needs to be visible to code generated by configure_me
-#[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)]
+#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)]
 pub enum TaskSchedulingPolicy {
     PullStaged,
     PushStaged,
@@ -317,7 +317,7 @@ impl std::str::FromStr for TaskSchedulingPolicy {
     type Err = String;
 
     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
-        ArgEnum::from_str(s, true)
+        ValueEnum::from_str(s, true)
     }
 }
 
@@ -329,7 +329,7 @@ impl parse_arg::ParseArgFromStr for TaskSchedulingPolicy {
 
 // an enum used to configure the log rolling policy
 // needs to be visible to code generated by configure_me
-#[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)]
+#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)]
 pub enum LogRotationPolicy {
     Minutely,
     Hourly,
@@ -341,7 +341,7 @@ impl std::str::FromStr for LogRotationPolicy {
     type Err = String;
 
     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
-        ArgEnum::from_str(s, true)
+        ValueEnum::from_str(s, true)
     }
 }
 
@@ -353,7 +353,7 @@ impl parse_arg::ParseArgFromStr for LogRotationPolicy {
 
 // an enum used to configure the source data cache policy
 // needs to be visible to code generated by configure_me
-#[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)]
+#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)]
 pub enum DataCachePolicy {
     LocalDiskFile,
 }
@@ -362,7 +362,7 @@ impl std::str::FromStr for DataCachePolicy {
     type Err = String;
 
     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
-        ArgEnum::from_str(s, true)
+        ValueEnum::from_str(s, true)
     }
 }
 
diff --git a/ballista/core/src/execution_plans/distributed_query.rs 
b/ballista/core/src/execution_plans/distributed_query.rs
index b96367bb..050ba877 100644
--- a/ballista/core/src/execution_plans/distributed_query.rs
+++ b/ballista/core/src/execution_plans/distributed_query.rs
@@ -154,6 +154,10 @@ impl<T: 'static + AsLogicalPlan> DisplayAs for 
DistributedQueryExec<T> {
 }
 
 impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
+    fn name(&self) -> &str {
+        "DistributedQueryExec"
+    }
+
     fn as_any(&self) -> &dyn Any {
         self
     }
diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs 
b/ballista/core/src/execution_plans/shuffle_reader.rs
index 79dfe296..2f856b39 100644
--- a/ballista/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/core/src/execution_plans/shuffle_reader.rs
@@ -107,6 +107,10 @@ impl DisplayAs for ShuffleReaderExec {
 }
 
 impl ExecutionPlan for ShuffleReaderExec {
+    fn name(&self) -> &str {
+        "ShuffleReaderExec"
+    }
+
     fn as_any(&self) -> &dyn Any {
         self
     }
diff --git a/ballista/core/src/execution_plans/shuffle_writer.rs 
b/ballista/core/src/execution_plans/shuffle_writer.rs
index 7f21b18b..87e4feea 100644
--- a/ballista/core/src/execution_plans/shuffle_writer.rs
+++ b/ballista/core/src/execution_plans/shuffle_writer.rs
@@ -355,6 +355,10 @@ impl DisplayAs for ShuffleWriterExec {
 }
 
 impl ExecutionPlan for ShuffleWriterExec {
+    fn name(&self) -> &str {
+        "ShuffleWriterExec"
+    }
+
     fn as_any(&self) -> &dyn Any {
         self
     }
diff --git a/ballista/core/src/execution_plans/unresolved_shuffle.rs 
b/ballista/core/src/execution_plans/unresolved_shuffle.rs
index b3c30c0d..e227e2ac 100644
--- a/ballista/core/src/execution_plans/unresolved_shuffle.rs
+++ b/ballista/core/src/execution_plans/unresolved_shuffle.rs
@@ -82,6 +82,10 @@ impl DisplayAs for UnresolvedShuffleExec {
 }
 
 impl ExecutionPlan for UnresolvedShuffleExec {
+    fn name(&self) -> &str {
+        "UnresolvedShuffleExec"
+    }
+
     fn as_any(&self) -> &dyn Any {
         self
     }
diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs
index 08208eed..2bb555d1 100644
--- a/ballista/core/src/serde/mod.rs
+++ b/ballista/core/src/serde/mod.rs
@@ -21,9 +21,13 @@
 use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction};
 
 use arrow_flight::sql::ProstMessageExt;
-use datafusion::common::DataFusionError;
+use datafusion::common::{DataFusionError, Result};
 use datafusion::execution::FunctionRegistry;
 use datafusion::physical_plan::{ExecutionPlan, Partitioning};
+use datafusion_proto::logical_plan::file_formats::{
+    ArrowLogicalExtensionCodec, AvroLogicalExtensionCodec, 
CsvLogicalExtensionCodec,
+    JsonLogicalExtensionCodec, ParquetLogicalExtensionCodec,
+};
 use 
datafusion_proto::physical_plan::from_proto::parse_protobuf_hash_partitioning;
 use datafusion_proto::protobuf::proto_error;
 use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
@@ -84,7 +88,7 @@ pub struct BallistaCodec<
 impl Default for BallistaCodec {
     fn default() -> Self {
         Self {
-            logical_extension_codec: Arc::new(DefaultLogicalExtensionCodec {}),
+            logical_extension_codec: 
Arc::new(BallistaLogicalExtensionCodec::default()),
             physical_extension_codec: Arc::new(BallistaPhysicalExtensionCodec 
{}),
             logical_plan_repr: PhantomData,
             physical_plan_repr: PhantomData,
@@ -114,6 +118,102 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> BallistaCodec<T,
     }
 }
 
+#[derive(Debug)]
+pub struct BallistaLogicalExtensionCodec {
+    default_codec: Arc<dyn LogicalExtensionCodec>,
+    file_format_codecs: Vec<Arc<dyn LogicalExtensionCodec>>,
+}
+
+impl BallistaLogicalExtensionCodec {
+    fn try_any<T>(
+        &self,
+        mut f: impl FnMut(&dyn LogicalExtensionCodec) -> Result<T>,
+    ) -> Result<T> {
+        let mut last_err = None;
+        for codec in &self.file_format_codecs {
+            match f(codec.as_ref()) {
+                Ok(node) => return Ok(node),
+                Err(err) => last_err = Some(err),
+            }
+        }
+
+        Err(last_err.unwrap_or_else(|| {
+            DataFusionError::NotImplemented("Empty list of composed 
codecs".to_owned())
+        }))
+    }
+}
+
+impl Default for BallistaLogicalExtensionCodec {
+    fn default() -> Self {
+        Self {
+            default_codec: Arc::new(DefaultLogicalExtensionCodec {}),
+            file_format_codecs: vec![
+                Arc::new(CsvLogicalExtensionCodec {}),
+                Arc::new(JsonLogicalExtensionCodec {}),
+                Arc::new(ParquetLogicalExtensionCodec {}),
+                Arc::new(ArrowLogicalExtensionCodec {}),
+                Arc::new(AvroLogicalExtensionCodec {}),
+            ],
+        }
+    }
+}
+
+impl LogicalExtensionCodec for BallistaLogicalExtensionCodec {
+    fn try_decode(
+        &self,
+        buf: &[u8],
+        inputs: &[datafusion::logical_expr::LogicalPlan],
+        ctx: &datafusion::prelude::SessionContext,
+    ) -> Result<datafusion::logical_expr::Extension> {
+        self.default_codec.try_decode(buf, inputs, ctx)
+    }
+
+    fn try_encode(
+        &self,
+        node: &datafusion::logical_expr::Extension,
+        buf: &mut Vec<u8>,
+    ) -> Result<()> {
+        self.default_codec.try_encode(node, buf)
+    }
+
+    fn try_decode_table_provider(
+        &self,
+        buf: &[u8],
+        table_ref: &datafusion::sql::TableReference,
+        schema: datafusion::arrow::datatypes::SchemaRef,
+        ctx: &datafusion::prelude::SessionContext,
+    ) -> Result<Arc<dyn datafusion::catalog::TableProvider>> {
+        self.default_codec
+            .try_decode_table_provider(buf, table_ref, schema, ctx)
+    }
+
+    fn try_encode_table_provider(
+        &self,
+        table_ref: &datafusion::sql::TableReference,
+        node: Arc<dyn datafusion::catalog::TableProvider>,
+        buf: &mut Vec<u8>,
+    ) -> Result<()> {
+        self.default_codec
+            .try_encode_table_provider(table_ref, node, buf)
+    }
+
+    fn try_decode_file_format(
+        &self,
+        buf: &[u8],
+        ctx: &datafusion::prelude::SessionContext,
+    ) -> Result<Arc<dyn 
datafusion::datasource::file_format::FileFormatFactory>> {
+        self.try_any(|codec| codec.try_decode_file_format(buf, ctx))
+    }
+
+    fn try_encode_file_format(
+        &self,
+        buf: &mut Vec<u8>,
+        node: Arc<dyn datafusion::datasource::file_format::FileFormatFactory>,
+    ) -> Result<()> {
+        self.try_any(|codec| codec.try_encode_file_format(buf, node.clone()))
+    }
+}
+
 #[derive(Debug)]
 pub struct BallistaPhysicalExtensionCodec {}
 
diff --git a/ballista/core/src/serde/scheduler/mod.rs 
b/ballista/core/src/serde/scheduler/mod.rs
index 0ced200e..23c9c425 100644
--- a/ballista/core/src/serde/scheduler/mod.rs
+++ b/ballista/core/src/serde/scheduler/mod.rs
@@ -25,6 +25,7 @@ use datafusion::arrow::array::{
 use datafusion::arrow::datatypes::{DataType, Field};
 use datafusion::common::DataFusionError;
 use datafusion::execution::FunctionRegistry;
+use datafusion::logical_expr::planner::ExprPlanner;
 use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::physical_plan::Partitioning;
@@ -299,6 +300,10 @@ pub struct SimpleFunctionRegistry {
 }
 
 impl FunctionRegistry for SimpleFunctionRegistry {
+    fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
+        vec![]
+    }
+
     fn udfs(&self) -> HashSet<String> {
         self.scalar_functions.keys().cloned().collect()
     }
diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs
index 45a4f53f..7e88ffaf 100644
--- a/ballista/core/src/utils.rs
+++ b/ballista/core/src/utils.rs
@@ -35,6 +35,7 @@ use datafusion::execution::context::{
     QueryPlanner, SessionConfig, SessionContext, SessionState,
 };
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
+use datafusion::execution::session_state::SessionStateBuilder;
 use datafusion::logical_expr::{DdlStatement, LogicalPlan};
 use datafusion::physical_plan::aggregates::AggregateExec;
 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
@@ -62,13 +63,14 @@ use tonic::transport::{Channel, Error, Server};
 
 /// Default session builder using the provided configuration
 pub fn default_session_builder(config: SessionConfig) -> SessionState {
-    SessionState::new_with_config_rt(
-        config,
-        Arc::new(
+    SessionStateBuilder::new()
+        .with_default_features()
+        .with_config(config)
+        .with_runtime_env(Arc::new(
             
RuntimeEnv::new(with_object_store_registry(RuntimeConfig::default()))
                 .unwrap(),
-        ),
-    )
+        ))
+        .build()
 }
 
 /// Stream data to disk in Arrow IPC format
@@ -252,15 +254,16 @@ pub fn create_df_ctx_with_ballista_query_planner<T: 
'static + AsLogicalPlan>(
     let session_config = SessionConfig::new()
         .with_target_partitions(config.default_shuffle_partitions())
         .with_information_schema(true);
-    let mut session_state = SessionState::new_with_config_rt(
-        session_config,
-        Arc::new(
+    let session_state = SessionStateBuilder::new()
+        .with_default_features()
+        .with_config(session_config)
+        .with_runtime_env(Arc::new(
             
RuntimeEnv::new(with_object_store_registry(RuntimeConfig::default()))
                 .unwrap(),
-        ),
-    )
-    .with_query_planner(planner);
-    session_state = session_state.with_session_id(session_id);
+        ))
+        .with_query_planner(planner)
+        .with_session_id(session_id)
+        .build();
     // the SessionContext created here is the client side context, but the 
session_id is from server side.
     SessionContext::new_with_state(session_state)
 }
diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml
index 6bebaa87..e0ca6efb 100644
--- a/ballista/executor/Cargo.toml
+++ b/ballista/executor/Cargo.toml
@@ -48,7 +48,6 @@ dashmap = "5.4.0"
 datafusion = { workspace = true }
 datafusion-proto = { workspace = true }
 futures = "0.3"
-hyper = "0.14.4"
 log = "0.4"
 mimalloc = { version = "0.1", default-features = false, optional = true }
 num_cpus = "1.13.0"
diff --git a/ballista/executor/src/collect.rs b/ballista/executor/src/collect.rs
index eb96e314..1d77e719 100644
--- a/ballista/executor/src/collect.rs
+++ b/ballista/executor/src/collect.rs
@@ -67,6 +67,10 @@ impl DisplayAs for CollectExec {
 }
 
 impl ExecutionPlan for CollectExec {
+    fn name(&self) -> &str {
+        "CollectExec"
+    }
+
     fn as_any(&self) -> &dyn Any {
         self
     }
diff --git a/ballista/executor/src/executor.rs 
b/ballista/executor/src/executor.rs
index ccc7f273..4e83b125 100644
--- a/ballista/executor/src/executor.rs
+++ b/ballista/executor/src/executor.rs
@@ -28,6 +28,8 @@ use ballista_core::serde::scheduler::PartitionId;
 use dashmap::DashMap;
 use datafusion::execution::context::TaskContext;
 use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::functions::all_default_functions;
+use datafusion::functions_aggregate::all_default_aggregate_functions;
 use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
 use futures::future::AbortHandle;
 use std::collections::HashMap;
@@ -103,12 +105,22 @@ impl Executor {
         concurrent_tasks: usize,
         execution_engine: Option<Arc<dyn ExecutionEngine>>,
     ) -> Self {
+        let scalar_functions = all_default_functions()
+            .into_iter()
+            .map(|f| (f.name().to_string(), f))
+            .collect();
+
+        let aggregate_functions = all_default_aggregate_functions()
+            .into_iter()
+            .map(|f| (f.name().to_string(), f))
+            .collect();
+
         Self {
             metadata,
             work_dir: work_dir.to_owned(),
-            // TODO add logic to dynamically load UDF/UDAFs libs from files
-            scalar_functions: HashMap::new(),
-            aggregate_functions: HashMap::new(),
+            scalar_functions,
+            aggregate_functions,
+            // TODO: set to default window functions when they are moved to 
udwf
             window_functions: HashMap::new(),
             runtime,
             runtime_with_data_cache,
@@ -277,6 +289,10 @@ mod test {
     }
 
     impl ExecutionPlan for NeverendingOperator {
+        fn name(&self) -> &str {
+            "NeverendingOperator"
+        }
+
         fn as_any(&self) -> &dyn Any {
             self
         }
diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml
index dd878b9a..596db6d1 100644
--- a/ballista/scheduler/Cargo.toml
+++ b/ballista/scheduler/Cargo.toml
@@ -46,20 +46,19 @@ anyhow = "1"
 arrow-flight = { workspace = true }
 async-recursion = "1.0.0"
 async-trait = "0.1.41"
+axum = "0.6.20"
 ballista-core = { path = "../core", version = "0.12.0", features = ["s3"] }
 base64 = { version = "0.21" }
-clap = { version = "3", features = ["derive", "cargo"] }
+clap = { workspace = true }
 configure_me = { workspace = true }
 dashmap = "5.4.0"
 datafusion = { workspace = true }
 datafusion-proto = { workspace = true }
-etcd-client = { version = "0.12", optional = true }
+etcd-client = { version = "0.14", optional = true }
 flatbuffers = { version = "23.5.26" }
 futures = "0.3"
 graphviz-rust = "0.8.0"
-http = "0.2"
-http-body = "0.4"
-hyper = "0.14.4"
+http = "0.2.9"
 itertools = "0.12.0"
 log = "0.4"
 object_store = { workspace = true }
@@ -67,20 +66,20 @@ once_cell = { version = "1.16.0", optional = true }
 parking_lot = "0.12"
 parse_arg = "0.1.3"
 prometheus = { version = "0.13", features = ["process"], optional = true }
-prost = "0.12"
-prost-types = { version = "0.12.0" }
+prost = { workspace = true }
+prost-types = { workspace = true }
 rand = "0.8"
 serde = { version = "1", features = ["derive"] }
 sled_package = { package = "sled", version = "0.34", optional = true }
 tokio = { version = "1.0", features = ["full"] }
 tokio-stream = { version = "0.1", features = ["net"], optional = true }
 tonic = { workspace = true }
-tower = { version = "0.4" }
+# tonic 0.12.2 depends on tower 0.4.7
+tower = { version = "0.4.7", default-features = false, features = ["make", 
"util"] }
 tracing = { workspace = true }
 tracing-appender = { workspace = true }
 tracing-subscriber = { workspace = true }
 uuid = { version = "1.0", features = ["v4"] }
-warp = "0.3"
 
 [dev-dependencies]
 ballista-core = { path = "../core", version = "0.12.0" }
diff --git a/ballista/scheduler/src/api/handlers.rs 
b/ballista/scheduler/src/api/handlers.rs
index 463ca217..4d0366ff 100644
--- a/ballista/scheduler/src/api/handlers.rs
+++ b/ballista/scheduler/src/api/handlers.rs
@@ -14,6 +14,11 @@ use crate::scheduler_server::event::QueryStageSchedulerEvent;
 use crate::scheduler_server::SchedulerServer;
 use crate::state::execution_graph::ExecutionStage;
 use crate::state::execution_graph_dot::ExecutionGraphDot;
+use axum::{
+    extract::{Path, State},
+    response::{IntoResponse, Response},
+    Json,
+};
 use ballista_core::serde::protobuf::job_status::Status;
 use ballista_core::BALLISTA_VERSION;
 use datafusion::physical_plan::metrics::{MetricValue, MetricsSet, Time};
@@ -22,10 +27,9 @@ use datafusion_proto::physical_plan::AsExecutionPlan;
 use graphviz_rust::cmd::{CommandArg, Format};
 use graphviz_rust::exec;
 use graphviz_rust::printer::PrinterContext;
-use http::header::CONTENT_TYPE;
-
+use http::{header::CONTENT_TYPE, StatusCode};
+use std::sync::Arc;
 use std::time::Duration;
-use warp::Rejection;
 
 #[derive(Debug, serde::Serialize)]
 struct SchedulerStateResponse {
@@ -64,22 +68,26 @@ pub struct QueryStageSummary {
     pub elapsed_compute: String,
 }
 
-/// Return current scheduler state
-pub(crate) async fn get_scheduler_state<T: AsLogicalPlan, U: AsExecutionPlan>(
-    data_server: SchedulerServer<T, U>,
-) -> Result<impl warp::Reply, Rejection> {
+pub async fn get_scheduler_state<
+    T: AsLogicalPlan + Clone + Send + Sync + 'static,
+    U: AsExecutionPlan + Send + Sync + 'static,
+>(
+    State(data_server): State<Arc<SchedulerServer<T, U>>>,
+) -> impl IntoResponse {
     let response = SchedulerStateResponse {
         started: data_server.start_time,
         version: BALLISTA_VERSION,
     };
-    Ok(warp::reply::json(&response))
+    Json(response)
 }
 
-/// Return list of executors
-pub(crate) async fn get_executors<T: AsLogicalPlan, U: AsExecutionPlan>(
-    data_server: SchedulerServer<T, U>,
-) -> Result<impl warp::Reply, Rejection> {
-    let state = data_server.state;
+pub async fn get_executors<
+    T: AsLogicalPlan + Clone + Send + Sync + 'static,
+    U: AsExecutionPlan + Send + Sync + 'static,
+>(
+    State(data_server): State<Arc<SchedulerServer<T, U>>>,
+) -> impl IntoResponse {
+    let state = &data_server.state;
     let executors: Vec<ExecutorMetaResponse> = state
         .executor_manager
         .get_executor_state()
@@ -94,21 +102,23 @@ pub(crate) async fn get_executors<T: AsLogicalPlan, U: 
AsExecutionPlan>(
         })
         .collect();
 
-    Ok(warp::reply::json(&executors))
+    Json(executors)
 }
 
-/// Return list of jobs
-pub(crate) async fn get_jobs<T: AsLogicalPlan, U: AsExecutionPlan>(
-    data_server: SchedulerServer<T, U>,
-) -> Result<impl warp::Reply, Rejection> {
+pub async fn get_jobs<
+    T: AsLogicalPlan + Clone + Send + Sync + 'static,
+    U: AsExecutionPlan + Send + Sync + 'static,
+>(
+    State(data_server): State<Arc<SchedulerServer<T, U>>>,
+) -> Result<impl IntoResponse, StatusCode> {
     // TODO: Display last seen information in UI
-    let state = data_server.state;
+    let state = &data_server.state;
 
     let jobs = state
         .task_manager
         .get_jobs()
         .await
-        .map_err(|_| warp::reject())?;
+        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
 
     let jobs: Vec<JobResponse> = jobs
         .iter()
@@ -157,31 +167,34 @@ pub(crate) async fn get_jobs<T: AsLogicalPlan, U: 
AsExecutionPlan>(
         })
         .collect();
 
-    Ok(warp::reply::json(&jobs))
+    Ok(Json(jobs))
 }
 
-pub(crate) async fn cancel_job<T: AsLogicalPlan, U: AsExecutionPlan>(
-    data_server: SchedulerServer<T, U>,
-    job_id: String,
-) -> Result<impl warp::Reply, Rejection> {
+pub async fn cancel_job<
+    T: AsLogicalPlan + Clone + Send + Sync + 'static,
+    U: AsExecutionPlan + Send + Sync + 'static,
+>(
+    State(data_server): State<Arc<SchedulerServer<T, U>>>,
+    Path(job_id): Path<String>,
+) -> Result<impl IntoResponse, StatusCode> {
     // 404 if job doesn't exist
     data_server
         .state
         .task_manager
         .get_job_status(&job_id)
         .await
-        .map_err(|_| warp::reject())?
-        .ok_or_else(warp::reject)?;
+        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
+        .ok_or(StatusCode::NOT_FOUND)?;
 
     data_server
         .query_stage_event_loop
         .get_sender()
-        .map_err(|_| warp::reject())?
+        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
         .post_event(QueryStageSchedulerEvent::JobCancel(job_id))
         .await
-        .map_err(|_| warp::reject())?;
+        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
 
-    Ok(warp::reply::json(&CancelJobResponse { cancelled: true }))
+    Ok(Json(CancelJobResponse { cancelled: true }))
 }
 
 #[derive(Debug, serde::Serialize)]
@@ -189,69 +202,71 @@ pub struct QueryStagesResponse {
     pub stages: Vec<QueryStageSummary>,
 }
 
-/// Get the execution graph for the specified job id
-pub(crate) async fn get_query_stages<T: AsLogicalPlan, U: AsExecutionPlan>(
-    data_server: SchedulerServer<T, U>,
-    job_id: String,
-) -> Result<impl warp::Reply, Rejection> {
+pub async fn get_query_stages<
+    T: AsLogicalPlan + Clone + Send + Sync + 'static,
+    U: AsExecutionPlan + Send + Sync + 'static,
+>(
+    State(data_server): State<Arc<SchedulerServer<T, U>>>,
+    Path(job_id): Path<String>,
+) -> Result<impl IntoResponse, StatusCode> {
     if let Some(graph) = data_server
         .state
         .task_manager
         .get_job_execution_graph(&job_id)
         .await
-        .map_err(|_| warp::reject())?
+        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
     {
-        Ok(warp::reply::json(&QueryStagesResponse {
-            stages: graph
-                .as_ref()
-                .stages()
-                .iter()
-                .map(|(id, stage)| {
-                    let mut summary = QueryStageSummary {
-                        stage_id: id.to_string(),
-                        stage_status: stage.variant_name().to_string(),
-                        input_rows: 0,
-                        output_rows: 0,
-                        elapsed_compute: "".to_string(),
-                    };
-                    match stage {
-                        ExecutionStage::Running(running_stage) => {
-                            summary.input_rows = running_stage
-                                .stage_metrics
-                                .as_ref()
-                                .map(|m| get_combined_count(m.as_slice(), 
"input_rows"))
-                                .unwrap_or(0);
-                            summary.output_rows = running_stage
-                                .stage_metrics
-                                .as_ref()
-                                .map(|m| get_combined_count(m.as_slice(), 
"output_rows"))
-                                .unwrap_or(0);
-                            summary.elapsed_compute = running_stage
-                                .stage_metrics
-                                .as_ref()
-                                .map(|m| 
get_elapsed_compute_nanos(m.as_slice()))
-                                .unwrap_or_default();
-                        }
-                        ExecutionStage::Successful(completed_stage) => {
-                            summary.input_rows = get_combined_count(
-                                &completed_stage.stage_metrics,
-                                "input_rows",
-                            );
-                            summary.output_rows = get_combined_count(
-                                &completed_stage.stage_metrics,
-                                "output_rows",
-                            );
-                            summary.elapsed_compute =
-                                
get_elapsed_compute_nanos(&completed_stage.stage_metrics);
-                        }
-                        _ => {}
+        let stages = graph
+            .as_ref()
+            .stages()
+            .iter()
+            .map(|(id, stage)| {
+                let mut summary = QueryStageSummary {
+                    stage_id: id.to_string(),
+                    stage_status: stage.variant_name().to_string(),
+                    input_rows: 0,
+                    output_rows: 0,
+                    elapsed_compute: "".to_string(),
+                };
+                match stage {
+                    ExecutionStage::Running(running_stage) => {
+                        summary.input_rows = running_stage
+                            .stage_metrics
+                            .as_ref()
+                            .map(|m| get_combined_count(m.as_slice(), 
"input_rows"))
+                            .unwrap_or(0);
+                        summary.output_rows = running_stage
+                            .stage_metrics
+                            .as_ref()
+                            .map(|m| get_combined_count(m.as_slice(), 
"output_rows"))
+                            .unwrap_or(0);
+                        summary.elapsed_compute = running_stage
+                            .stage_metrics
+                            .as_ref()
+                            .map(|m| get_elapsed_compute_nanos(m.as_slice()))
+                            .unwrap_or_default();
+                    }
+                    ExecutionStage::Successful(completed_stage) => {
+                        summary.input_rows = get_combined_count(
+                            &completed_stage.stage_metrics,
+                            "input_rows",
+                        );
+                        summary.output_rows = get_combined_count(
+                            &completed_stage.stage_metrics,
+                            "output_rows",
+                        );
+                        summary.elapsed_compute =
+                            
get_elapsed_compute_nanos(&completed_stage.stage_metrics);
                     }
-                    summary
-                })
-                .collect(),
-        }))
+                    _ => {}
+                }
+                summary
+            })
+            .collect();
+
+        Ok(Json(QueryStagesResponse { stages }))
     } else {
-        Ok(warp::reply::json(&QueryStagesResponse { stages: vec![] }))
+        Ok(Json(QueryStagesResponse { stages: vec![] }))
     }
 }
 
@@ -286,78 +301,96 @@ fn get_combined_count(metrics: &[MetricsSet], name: &str) 
-> usize {
         .sum()
 }
 
-/// Generate a dot graph for the specified job id and return as plain text
-pub(crate) async fn get_job_dot_graph<T: AsLogicalPlan, U: AsExecutionPlan>(
-    data_server: SchedulerServer<T, U>,
-    job_id: String,
-) -> Result<String, Rejection> {
+pub async fn get_job_dot_graph<
+    T: AsLogicalPlan + Clone + Send + Sync + 'static,
+    U: AsExecutionPlan + Send + Sync + 'static,
+>(
+    State(data_server): State<Arc<SchedulerServer<T, U>>>,
+    Path(job_id): Path<String>,
+) -> Result<String, StatusCode> {
     if let Some(graph) = data_server
         .state
         .task_manager
         .get_job_execution_graph(&job_id)
         .await
-        .map_err(|_| warp::reject())?
+        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
     {
-        ExecutionGraphDot::generate(graph.as_ref()).map_err(|_| warp::reject())
+        ExecutionGraphDot::generate(graph.as_ref())
+            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
     } else {
         Ok("Not Found".to_string())
     }
 }
 
-/// Generate a dot graph for the specified job id and query stage and return 
as plain text
-pub(crate) async fn get_query_stage_dot_graph<T: AsLogicalPlan, U: 
AsExecutionPlan>(
-    data_server: SchedulerServer<T, U>,
-    job_id: String,
-    stage_id: usize,
-) -> Result<String, Rejection> {
+pub async fn get_query_stage_dot_graph<
+    T: AsLogicalPlan + Clone + Send + Sync + 'static,
+    U: AsExecutionPlan + Send + Sync + 'static,
+>(
+    State(data_server): State<Arc<SchedulerServer<T, U>>>,
+    Path((job_id, stage_id)): Path<(String, usize)>,
+) -> Result<impl IntoResponse, StatusCode> {
     if let Some(graph) = data_server
         .state
         .task_manager
         .get_job_execution_graph(&job_id)
         .await
-        .map_err(|_| warp::reject())?
+        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
     {
         ExecutionGraphDot::generate_for_query_stage(graph.as_ref(), stage_id)
-            .map_err(|_| warp::reject())
+            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
     } else {
         Ok("Not Found".to_string())
     }
 }
 
-/// Generate an SVG graph for the specified job id and return it as plain text
-pub(crate) async fn get_job_svg_graph<T: AsLogicalPlan, U: AsExecutionPlan>(
-    data_server: SchedulerServer<T, U>,
-    job_id: String,
-) -> Result<String, Rejection> {
-    let dot = get_job_dot_graph(data_server, job_id).await;
-    match dot {
-        Ok(dot) => {
-            let graph = graphviz_rust::parse(&dot);
-            if let Ok(graph) = graph {
-                exec(
-                    graph,
-                    &mut PrinterContext::default(),
-                    vec![CommandArg::Format(Format::Svg)],
-                )
-                .map(|bytes| String::from_utf8_lossy(&bytes).to_string())
-                .map_err(|_| warp::reject())
-            } else {
-                Ok("Cannot parse graph".to_string())
-            }
+pub async fn get_job_svg_graph<
+    T: AsLogicalPlan + Clone + Send + Sync + 'static,
+    U: AsExecutionPlan + Send + Sync + 'static,
+>(
+    State(data_server): State<Arc<SchedulerServer<T, U>>>,
+    Path(job_id): Path<String>,
+) -> Result<impl IntoResponse, StatusCode> {
+    let dot = get_job_dot_graph(State(data_server.clone()), 
Path(job_id)).await?;
+    match graphviz_rust::parse(&dot) {
+        Ok(graph) => {
+            let result = exec(
+                graph,
+                &mut PrinterContext::default(),
+                vec![CommandArg::Format(Format::Svg)],
+            )
+            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+
+            let svg = String::from_utf8_lossy(&result).to_string();
+            Ok(Response::builder()
+                .header(CONTENT_TYPE, "image/svg+xml")
+                .body(svg)
+                .unwrap())
         }
-        _ => Ok("Not Found".to_string()),
+        Err(_) => Ok(Response::builder()
+            .status(StatusCode::BAD_REQUEST)
+            .body("Cannot parse graph".to_string())
+            .unwrap()),
     }
 }
 
-pub(crate) async fn get_scheduler_metrics<T: AsLogicalPlan, U: 
AsExecutionPlan>(
-    data_server: SchedulerServer<T, U>,
-) -> Result<impl warp::Reply, Rejection> {
-    Ok(data_server
-        .metrics_collector()
-        .gather_metrics()
-        .map_err(|_| warp::reject())?
-        .map(|(data, content_type)| {
-            warp::reply::with_header(data, CONTENT_TYPE, content_type)
-        })
-        .unwrap_or_else(|| warp::reply::with_header(vec![], CONTENT_TYPE, 
"text/html")))
+pub async fn get_scheduler_metrics<
+    T: AsLogicalPlan + Clone + Send + Sync + 'static,
+    U: AsExecutionPlan + Send + Sync + 'static,
+>(
+    State(data_server): State<Arc<SchedulerServer<T, U>>>,
+) -> impl IntoResponse {
+    match data_server.metrics_collector().gather_metrics() {
+        Ok(Some((data, content_type))) => Response::builder()
+            .header(CONTENT_TYPE, content_type)
+            .body(axum::body::Body::from(data))
+            .unwrap(),
+        Ok(None) => Response::builder()
+            .status(StatusCode::NO_CONTENT)
+            .body(axum::body::Body::empty())
+            .unwrap(),
+        Err(_) => Response::builder()
+            .status(StatusCode::INTERNAL_SERVER_ERROR)
+            .body(axum::body::Body::empty())
+            .unwrap(),
+    }
 }
diff --git a/ballista/scheduler/src/api/mod.rs 
b/ballista/scheduler/src/api/mod.rs
index 8f5555d0..c33d5157 100644
--- a/ballista/scheduler/src/api/mod.rs
+++ b/ballista/scheduler/src/api/mod.rs
@@ -13,126 +13,39 @@
 mod handlers;
 
 use crate::scheduler_server::SchedulerServer;
-use anyhow::Result;
+use axum::routing::patch;
+use axum::{routing::get, Router};
 use datafusion_proto::logical_plan::AsLogicalPlan;
 use datafusion_proto::physical_plan::AsExecutionPlan;
-use std::{
-    pin::Pin,
-    task::{Context as TaskContext, Poll},
-};
-use warp::filters::BoxedFilter;
-use warp::{Buf, Filter, Reply};
-
-pub enum EitherBody<A, B> {
-    Left(A),
-    Right(B),
-}
-
-pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
-pub type HttpBody = dyn http_body::Body<Data = dyn Buf, Error = Error> + 
'static;
-
-impl<A, B> http_body::Body for EitherBody<A, B>
-where
-    A: http_body::Body + Send + Unpin,
-    B: http_body::Body<Data = A::Data> + Send + Unpin,
-    A::Error: Into<Error>,
-    B::Error: Into<Error>,
-{
-    type Data = A::Data;
-    type Error = Error;
-
-    fn poll_data(
-        self: Pin<&mut Self>,
-        cx: &mut TaskContext<'_>,
-    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
-        match self.get_mut() {
-            EitherBody::Left(b) => 
Pin::new(b).poll_data(cx).map(map_option_err),
-            EitherBody::Right(b) => 
Pin::new(b).poll_data(cx).map(map_option_err),
-        }
-    }
-
-    fn poll_trailers(
-        self: Pin<&mut Self>,
-        cx: &mut TaskContext<'_>,
-    ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
-        match self.get_mut() {
-            EitherBody::Left(b) => 
Pin::new(b).poll_trailers(cx).map_err(Into::into),
-            EitherBody::Right(b) => 
Pin::new(b).poll_trailers(cx).map_err(Into::into),
-        }
-    }
-
-    fn is_end_stream(&self) -> bool {
-        match self {
-            EitherBody::Left(b) => b.is_end_stream(),
-            EitherBody::Right(b) => b.is_end_stream(),
-        }
-    }
-}
-
-fn map_option_err<T, U: Into<Error>>(
-    err: Option<Result<T, U>>,
-) -> Option<Result<T, Error>> {
-    err.map(|e| e.map_err(Into::into))
-}
-
-fn with_data_server<T: AsLogicalPlan + Clone, U: 'static + AsExecutionPlan>(
-    db: SchedulerServer<T, U>,
-) -> impl Filter<Extract = (SchedulerServer<T, U>,), Error = 
std::convert::Infallible> + Clone
-{
-    warp::any().map(move || db.clone())
-}
-
-pub fn get_routes<T: AsLogicalPlan + Clone, U: 'static + AsExecutionPlan>(
-    scheduler_server: SchedulerServer<T, U>,
-) -> BoxedFilter<(impl Reply,)> {
-    let route_scheduler_state = warp::path!("api" / "state")
-        .and(with_data_server(scheduler_server.clone()))
-        .and_then(handlers::get_scheduler_state);
-
-    let route_executors = warp::path!("api" / "executors")
-        .and(with_data_server(scheduler_server.clone()))
-        .and_then(handlers::get_executors);
-
-    let route_jobs = warp::path!("api" / "jobs")
-        .and(with_data_server(scheduler_server.clone()))
-        .and_then(|data_server| handlers::get_jobs(data_server));
-
-    let route_cancel_job = warp::path!("api" / "job" / String)
-        .and(warp::patch())
-        .and(with_data_server(scheduler_server.clone()))
-        .and_then(|job_id, data_server| handlers::cancel_job(data_server, 
job_id));
-
-    let route_query_stages = warp::path!("api" / "job" / String / "stages")
-        .and(with_data_server(scheduler_server.clone()))
-        .and_then(|job_id, data_server| 
handlers::get_query_stages(data_server, job_id));
-
-    let route_job_dot = warp::path!("api" / "job" / String / "dot")
-        .and(with_data_server(scheduler_server.clone()))
-        .and_then(|job_id, data_server| 
handlers::get_job_dot_graph(data_server, job_id));
-
-    let route_query_stage_dot =
-        warp::path!("api" / "job" / String / "stage" / usize / "dot")
-            .and(with_data_server(scheduler_server.clone()))
-            .and_then(|job_id, stage_id, data_server| {
-                handlers::get_query_stage_dot_graph(data_server, job_id, 
stage_id)
-            });
-
-    let route_job_dot_svg = warp::path!("api" / "job" / String / "dot_svg")
-        .and(with_data_server(scheduler_server.clone()))
-        .and_then(|job_id, data_server| 
handlers::get_job_svg_graph(data_server, job_id));
-
-    let route_scheduler_metrics = warp::path!("api" / "metrics")
-        .and(with_data_server(scheduler_server))
-        .and_then(|data_server| handlers::get_scheduler_metrics(data_server));
-
-    let routes = route_scheduler_state
-        .or(route_executors)
-        .or(route_jobs)
-        .or(route_cancel_job)
-        .or(route_query_stages)
-        .or(route_job_dot)
-        .or(route_query_stage_dot)
-        .or(route_job_dot_svg)
-        .or(route_scheduler_metrics);
-    routes.boxed()
+use std::sync::Arc;
+
+pub fn get_routes<
+    T: AsLogicalPlan + Clone + Send + Sync + 'static,
+    U: AsExecutionPlan + Send + Sync + 'static,
+>(
+    scheduler_server: Arc<SchedulerServer<T, U>>,
+) -> Router {
+    Router::new()
+        .route("/api/state", get(handlers::get_scheduler_state::<T, U>))
+        .route("/api/executors", get(handlers::get_executors::<T, U>))
+        .route("/api/jobs", get(handlers::get_jobs::<T, U>))
+        .route("/api/job/:job_id", patch(handlers::cancel_job::<T, U>))
+        .route(
+            "/api/job/:job_id/stages",
+            get(handlers::get_query_stages::<T, U>),
+        )
+        .route(
+            "/api/job/:job_id/dot",
+            get(handlers::get_job_dot_graph::<T, U>),
+        )
+        .route(
+            "/api/job/:job_id/stage/:stage_id/dot",
+            get(handlers::get_query_stage_dot_graph::<T, U>),
+        )
+        .route(
+            "/api/job/:job_id/dot_svg",
+            get(handlers::get_job_svg_graph::<T, U>),
+        )
+        .route("/api/metrics", get(handlers::get_scheduler_metrics::<T, U>))
+        .with_state(scheduler_server)
 }
diff --git a/ballista/scheduler/src/bin/main.rs 
b/ballista/scheduler/src/bin/main.rs
index ee9364c7..d2e2c9ce 100644
--- a/ballista/scheduler/src/bin/main.rs
+++ b/ballista/scheduler/src/bin/main.rs
@@ -47,8 +47,17 @@ mod config {
     ));
 }
 
-#[tokio::main]
-async fn main() -> Result<()> {
+fn main() -> Result<()> {
+    let runtime = tokio::runtime::Builder::new_multi_thread()
+        .enable_io()
+        .enable_time()
+        .thread_stack_size(32 * 1024 * 1024) // 32MB
+        .build()
+        .unwrap();
+
+    runtime.block_on(inner())
+}
+async fn inner() -> Result<()> {
     // parse options
     let (opt, _remaining_args) =
         
Config::including_optional_config_files(&["/etc/ballista/scheduler.toml"])
diff --git a/ballista/scheduler/src/cluster/mod.rs 
b/ballista/scheduler/src/cluster/mod.rs
index 8313f033..b7489a25 100644
--- a/ballista/scheduler/src/cluster/mod.rs
+++ b/ballista/scheduler/src/cluster/mod.rs
@@ -20,7 +20,7 @@ use std::fmt;
 use std::pin::Pin;
 use std::sync::Arc;
 
-use clap::ArgEnum;
+use clap::ValueEnum;
 use datafusion::common::tree_node::TreeNode;
 use datafusion::common::tree_node::TreeNodeRecursion;
 use datafusion::datasource::listing::PartitionedFile;
@@ -65,7 +65,7 @@ pub mod test_util;
 
 // an enum used to configure the backend
 // needs to be visible to code generated by configure_me
-#[derive(Debug, Clone, ArgEnum, serde::Deserialize, PartialEq, Eq)]
+#[derive(Debug, Clone, ValueEnum, serde::Deserialize, PartialEq, Eq)]
 pub enum ClusterStorage {
     Etcd,
     Memory,
@@ -76,7 +76,7 @@ impl std::str::FromStr for ClusterStorage {
     type Err = String;
 
     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
-        ArgEnum::from_str(s, true)
+        ValueEnum::from_str(s, true)
     }
 }
 
@@ -764,7 +764,10 @@ mod test {
     };
     use crate::state::execution_graph::ExecutionGraph;
     use crate::state::task_manager::JobInfoCache;
-    use crate::test_utils::{mock_completed_task, 
test_aggregation_plan_with_job_id};
+    use crate::test_utils::{
+        mock_completed_task, revive_graph_and_complete_next_stage,
+        test_aggregation_plan_with_job_id,
+    };
 
     #[tokio::test]
     async fn test_bind_task_bias() -> Result<()> {
@@ -1008,10 +1011,11 @@ mod test {
 
     async fn mock_graph(
         job_id: &str,
-        num_partition: usize,
+        num_target_partitions: usize,
         num_pending_task: usize,
     ) -> Result<ExecutionGraph> {
-        let mut graph = test_aggregation_plan_with_job_id(num_partition, 
job_id).await;
+        let mut graph =
+            test_aggregation_plan_with_job_id(num_target_partitions, 
job_id).await;
         let executor = ExecutorMetadata {
             id: "executor_0".to_string(),
             host: "localhost".to_string(),
@@ -1020,14 +1024,10 @@ mod test {
             specification: ExecutorSpecification { task_slots: 32 },
         };
 
-        if let Some(task) = graph.pop_next_task(&executor.id)? {
-            let task_status = mock_completed_task(task, &executor.id);
-            graph.update_task_status(&executor, vec![task_status], 1, 1)?;
-        }
-
-        graph.revive();
+        // complete first stage
+        revive_graph_and_complete_next_stage(&mut graph)?;
 
-        for _i in 0..num_partition - num_pending_task {
+        for _ in 0..num_target_partitions - num_pending_task {
             if let Some(task) = graph.pop_next_task(&executor.id)? {
                 let task_status = mock_completed_task(task, &executor.id);
                 graph.update_task_status(&executor, vec![task_status], 1, 1)?;
diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs
index d15e928c..82280911 100644
--- a/ballista/scheduler/src/config.rs
+++ b/ballista/scheduler/src/config.rs
@@ -19,7 +19,7 @@
 //! Ballista scheduler specific configuration
 
 use ballista_core::config::TaskSchedulingPolicy;
-use clap::ArgEnum;
+use clap::ValueEnum;
 use std::fmt;
 
 /// Configurations for the ballista scheduler of scheduling jobs and tasks
@@ -189,7 +189,7 @@ pub enum ClusterStorageConfig {
 /// Policy of distributing tasks to available executor slots
 ///
 /// It needs to be visible to code generated by configure_me
-#[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)]
+#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)]
 pub enum TaskDistribution {
     /// Eagerly assign tasks to executor slots. This will assign as many task 
slots per executor
     /// as are currently available
@@ -208,7 +208,7 @@ impl std::str::FromStr for TaskDistribution {
     type Err = String;
 
     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
-        ArgEnum::from_str(s, true)
+        ValueEnum::from_str(s, true)
     }
 }
 
diff --git a/ballista/scheduler/src/planner.rs 
b/ballista/scheduler/src/planner.rs
index 3da9f339..0e18a062 100644
--- a/ballista/scheduler/src/planner.rs
+++ b/ballista/scheduler/src/planner.rs
@@ -592,6 +592,8 @@ order by
         Ok(())
     }
 
+    #[ignore]
+    // enable when upgrading Datafusion, a bug is fixed with 
https://github.com/apache/datafusion/pull/11926/
     #[tokio::test]
     async fn roundtrip_serde_aggregate() -> Result<(), BallistaError> {
         let ctx = datafusion_test_context("testdata").await?;
diff --git a/ballista/scheduler/src/scheduler_process.rs 
b/ballista/scheduler/src/scheduler_process.rs
index 6bcaaec5..1f7f7ac3 100644
--- a/ballista/scheduler/src/scheduler_process.rs
+++ b/ballista/scheduler/src/scheduler_process.rs
@@ -15,26 +15,18 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use anyhow::{Context, Result};
+use anyhow::{Error, Result};
 #[cfg(feature = "flight-sql")]
 use arrow_flight::flight_service_server::FlightServiceServer;
-use futures::future::{self, Either, TryFutureExt};
-use hyper::{server::conn::AddrStream, service::make_service_fn, Server};
-use log::info;
-use std::convert::Infallible;
-use std::net::SocketAddr;
-use std::sync::Arc;
-use tonic::transport::server::Connected;
-use tower::Service;
-
-use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
-
 use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer;
 use ballista_core::serde::BallistaCodec;
 use ballista_core::utils::create_grpc_server;
 use ballista_core::BALLISTA_VERSION;
+use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
+use log::info;
+use std::{net::SocketAddr, sync::Arc};
 
-use crate::api::{get_routes, EitherBody, Error};
+use crate::api::get_routes;
 use crate::cluster::BallistaCluster;
 use crate::config::SchedulerConfig;
 use crate::flight_sql::FlightSqlServiceImpl;
@@ -70,58 +62,31 @@ pub async fn start_server(
 
     scheduler_server.init().await?;
 
-    Server::bind(&addr)
-        .serve(make_service_fn(move |request: &AddrStream| {
-            let config = &scheduler_server.state.config;
-            let scheduler_grpc_server =
-                SchedulerGrpcServer::new(scheduler_server.clone())
-                    .max_encoding_message_size(
-                        config.grpc_server_max_encoding_message_size as usize,
-                    )
-                    .max_decoding_message_size(
-                        config.grpc_server_max_decoding_message_size as usize,
-                    );
-
-            let keda_scaler = 
ExternalScalerServer::new(scheduler_server.clone());
-
-            let tonic_builder = create_grpc_server()
-                .add_service(scheduler_grpc_server)
-                .add_service(keda_scaler);
+    let config = &scheduler_server.state.config;
+    let scheduler_grpc_server = 
SchedulerGrpcServer::new(scheduler_server.clone())
+        
.max_encoding_message_size(config.grpc_server_max_encoding_message_size as 
usize)
+        
.max_decoding_message_size(config.grpc_server_max_decoding_message_size as 
usize);
 
-            #[cfg(feature = "flight-sql")]
-            let tonic_builder = 
tonic_builder.add_service(FlightServiceServer::new(
-                FlightSqlServiceImpl::new(scheduler_server.clone()),
-            ));
+    let keda_scaler = ExternalScalerServer::new(scheduler_server.clone());
 
-            let mut tonic = tonic_builder.into_service();
+    let tonic_builder = create_grpc_server()
+        .add_service(scheduler_grpc_server)
+        .add_service(keda_scaler);
 
-            let mut warp = warp::service(get_routes(scheduler_server.clone()));
+    #[cfg(feature = "flight-sql")]
+    let tonic_builder = tonic_builder.add_service(FlightServiceServer::new(
+        FlightSqlServiceImpl::new(scheduler_server.clone()),
+    ));
 
-            let connect_info = request.connect_info();
-            future::ok::<_, Infallible>(tower::service_fn(
-                move |req: hyper::Request<hyper::Body>| {
-                    // Set the connect info from hyper to tonic
-                    let (mut parts, body) = req.into_parts();
-                    parts.extensions.insert(connect_info.clone());
-                    let req = http::Request::from_parts(parts, body);
+    let tonic = tonic_builder.into_service().into_router();
 
-                    if req.uri().path().starts_with("/api") {
-                        return Either::Left(
-                            warp.call(req)
-                                .map_ok(|res| res.map(EitherBody::Left))
-                                .map_err(Error::from),
-                        );
-                    }
+    let axum = get_routes(Arc::new(scheduler_server));
+    let merged = axum
+        .merge(tonic)
+        .into_make_service_with_connect_info::<SocketAddr>();
 
-                    Either::Right(
-                        tonic
-                            .call(req)
-                            .map_ok(|res| res.map(EitherBody::Right))
-                            .map_err(Error::from),
-                    )
-                },
-            ))
-        }))
+    axum::Server::bind(&addr)
+        .serve(merged)
         .await
-        .context("Could not start grpc server")
+        .map_err(Error::from)
 }
diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs 
b/ballista/scheduler/src/scheduler_server/grpc.rs
index 2d759fb7..6992bf75 100644
--- a/ballista/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/scheduler/src/scheduler_server/grpc.rs
@@ -15,10 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use axum::extract::ConnectInfo;
 use ballista_core::config::{BallistaConfig, BALLISTA_JOB_NAME};
 use ballista_core::serde::protobuf::execute_query_params::{OptionalSessionId, 
Query};
 use std::collections::HashMap;
 use std::convert::TryInto;
+use std::net::SocketAddr;
 
 use ballista_core::serde::protobuf::executor_registration::OptionalHost;
 use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc;
@@ -70,7 +72,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
                 "Bad request because poll work is not supported for push-based 
task scheduling",
             ));
         }
-        let remote_addr = request.remote_addr();
+        let remote_addr = extract_connect_info(&request);
         if let PollWorkParams {
             metadata: Some(metadata),
             num_free_slots,
@@ -155,7 +157,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
         &self,
         request: Request<RegisterExecutorParams>,
     ) -> Result<Response<RegisterExecutorResult>, Status> {
-        let remote_addr = request.remote_addr();
+        let remote_addr = extract_connect_info(&request);
         if let RegisterExecutorParams {
             metadata: Some(metadata),
         } = request.into_inner()
@@ -191,7 +193,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
         &self,
         request: Request<HeartBeatParams>,
     ) -> Result<Response<HeartBeatResult>, Status> {
-        let remote_addr = request.remote_addr();
+        let remote_addr = extract_connect_info(&request);
         let HeartBeatParams {
             executor_id,
             metrics,
@@ -634,6 +636,13 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
     }
 }
 
+fn extract_connect_info<T>(request: &Request<T>) -> 
Option<ConnectInfo<SocketAddr>> {
+    request
+        .extensions()
+        .get::<ConnectInfo<SocketAddr>>()
+        .cloned()
+}
+
 #[cfg(all(test, feature = "sled"))]
 mod test {
     use std::sync::Arc;
diff --git a/ballista/scheduler/src/scheduler_server/mod.rs 
b/ballista/scheduler/src/scheduler_server/mod.rs
index e6525f18..c2bf657b 100644
--- a/ballista/scheduler/src/scheduler_server/mod.rs
+++ b/ballista/scheduler/src/scheduler_server/mod.rs
@@ -345,7 +345,7 @@ mod test {
     use datafusion::functions_aggregate::sum::sum;
     use datafusion::logical_expr::{col, LogicalPlan};
 
-    use datafusion::test_util::scan_empty;
+    use datafusion::test_util::scan_empty_with_partitions;
     use datafusion_proto::protobuf::LogicalPlanNode;
     use datafusion_proto::protobuf::PhysicalPlanNode;
 
@@ -700,7 +700,9 @@ mod test {
             Field::new("gmv", DataType::UInt64, false),
         ]);
 
-        scan_empty(None, &schema, Some(vec![0, 1]))
+        // partitions need to be > 1 for DataFusion's optimizer to insert 
a repartition node
+        // behavior changed with: 
https://github.com/apache/datafusion/pull/11875
+        scan_empty_with_partitions(None, &schema, Some(vec![0, 1]), 2)
             .unwrap()
             .aggregate(vec![col("id")], vec![sum(col("gmv"))])
             .unwrap()
diff --git a/ballista/scheduler/src/state/execution_graph.rs 
b/ballista/scheduler/src/state/execution_graph.rs
index 9ee95d67..333545d3 100644
--- a/ballista/scheduler/src/state/execution_graph.rs
+++ b/ballista/scheduler/src/state/execution_graph.rs
@@ -1694,7 +1694,9 @@ mod test {
 
     use crate::state::execution_graph::ExecutionGraph;
     use crate::test_utils::{
-        mock_completed_task, mock_executor, mock_failed_task, 
test_aggregation_plan,
+        mock_completed_task, mock_executor, mock_failed_task,
+        revive_graph_and_complete_next_stage,
+        revive_graph_and_complete_next_stage_with_executor, 
test_aggregation_plan,
         test_coalesce_plan, test_join_plan, test_two_aggregations_plan,
         test_union_all_plan, test_union_plan,
     };
@@ -1793,19 +1795,13 @@ mod test {
         join_graph.revive();
 
         assert_eq!(join_graph.stage_count(), 4);
-        assert_eq!(join_graph.available_tasks(), 2);
+        assert_eq!(join_graph.available_tasks(), 4);
 
         // Complete the first stage
-        if let Some(task) = join_graph.pop_next_task(&executor1.id)? {
-            let task_status = mock_completed_task(task, &executor1.id);
-            join_graph.update_task_status(&executor1, vec![task_status], 1, 
1)?;
-        }
+        revive_graph_and_complete_next_stage_with_executor(&mut join_graph, 
&executor1)?;
 
         // Complete the second stage
-        if let Some(task) = join_graph.pop_next_task(&executor2.id)? {
-            let task_status = mock_completed_task(task, &executor2.id);
-            join_graph.update_task_status(&executor2, vec![task_status], 1, 
1)?;
-        }
+        revive_graph_and_complete_next_stage_with_executor(&mut join_graph, 
&executor2)?;
 
         join_graph.revive();
         // There are 4 tasks pending schedule for the 3rd stage
@@ -1823,7 +1819,7 @@ mod test {
 
         // Two stages were reset, 1 Running stage rollback to Unresolved and 1 
Completed stage move to Running
         assert_eq!(reset.0.len(), 2);
-        assert_eq!(join_graph.available_tasks(), 1);
+        assert_eq!(join_graph.available_tasks(), 2);
 
         drain_tasks(&mut join_graph)?;
         assert!(join_graph.is_successful(), "Failed to complete join plan");
@@ -1844,19 +1840,19 @@ mod test {
         join_graph.revive();
 
         assert_eq!(join_graph.stage_count(), 4);
-        assert_eq!(join_graph.available_tasks(), 2);
+        assert_eq!(join_graph.available_tasks(), 4);
 
         // Complete the first stage
-        if let Some(task) = join_graph.pop_next_task(&executor1.id)? {
-            let task_status = mock_completed_task(task, &executor1.id);
-            join_graph.update_task_status(&executor1, vec![task_status], 1, 
1)?;
-        }
+        assert_eq!(revive_graph_and_complete_next_stage(&mut join_graph)?, 2);
 
         // Complete the second stage
-        if let Some(task) = join_graph.pop_next_task(&executor2.id)? {
-            let task_status = mock_completed_task(task, &executor2.id);
-            join_graph.update_task_status(&executor2, vec![task_status], 1, 
1)?;
-        }
+        assert_eq!(
+            revive_graph_and_complete_next_stage_with_executor(
+                &mut join_graph,
+                &executor2
+            )?,
+            2
+        );
 
         // There are 0 tasks pending schedule now
         assert_eq!(join_graph.available_tasks(), 0);
@@ -1865,7 +1861,7 @@ mod test {
 
         // Two stages were reset, 1 Resolved stage rollback to Unresolved and 
1 Completed stage move to Running
         assert_eq!(reset.0.len(), 2);
-        assert_eq!(join_graph.available_tasks(), 1);
+        assert_eq!(join_graph.available_tasks(), 2);
 
         drain_tasks(&mut join_graph)?;
         assert!(join_graph.is_successful(), "Failed to complete join plan");
@@ -1886,13 +1882,10 @@ mod test {
         agg_graph.revive();
 
         assert_eq!(agg_graph.stage_count(), 2);
-        assert_eq!(agg_graph.available_tasks(), 1);
+        assert_eq!(agg_graph.available_tasks(), 2);
 
         // Complete the first stage
-        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
-            let task_status = mock_completed_task(task, &executor1.id);
-            agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
-        }
+        revive_graph_and_complete_next_stage_with_executor(&mut agg_graph, 
&executor1)?;
 
         // 1st task in the second stage
         if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
@@ -1920,12 +1913,12 @@ mod test {
 
         // Two stages were reset, 1 Running stage rollback to Unresolved and 1 
Completed stage move to Running
         assert_eq!(reset.0.len(), 2);
-        assert_eq!(agg_graph.available_tasks(), 1);
+        assert_eq!(agg_graph.available_tasks(), 2);
 
         // Call the reset again
         let reset = agg_graph.reset_stages_on_lost_executor(&executor1.id)?;
         assert_eq!(reset.0.len(), 0);
-        assert_eq!(agg_graph.available_tasks(), 1);
+        assert_eq!(agg_graph.available_tasks(), 2);
 
         drain_tasks(&mut agg_graph)?;
         assert!(agg_graph.is_successful(), "Failed to complete agg plan");
@@ -1935,24 +1928,20 @@ mod test {
 
     #[tokio::test]
     async fn test_do_not_retry_killed_task() -> Result<()> {
-        let executor1 = mock_executor("executor-id1".to_string());
-        let executor2 = mock_executor("executor-id2".to_string());
+        let executor = mock_executor("executor-id-123".to_string());
         let mut agg_graph = test_aggregation_plan(4).await;
         // Call revive to move the leaf Resolved stages to Running
         agg_graph.revive();
 
         // Complete the first stage
-        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
-            let task_status = mock_completed_task(task, &executor1.id);
-            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
-        }
+        revive_graph_and_complete_next_stage(&mut agg_graph)?;
 
         // 1st task in the second stage
-        let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
-        let task_status1 = mock_completed_task(task1, &executor2.id);
+        let task1 = agg_graph.pop_next_task(&executor.id)?.unwrap();
+        let task_status1 = mock_completed_task(task1, &executor.id);
 
         // 2nd task in the second stage
-        let task2 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
+        let task2 = agg_graph.pop_next_task(&executor.id)?.unwrap();
         let task_status2 = mock_failed_task(
             task2,
             FailedTask {
@@ -1964,7 +1953,7 @@ mod test {
         );
 
         agg_graph.update_task_status(
-            &executor2,
+            &executor,
             vec![task_status1, task_status2],
             4,
             4,
@@ -1983,24 +1972,20 @@ mod test {
 
     #[tokio::test]
     async fn test_max_task_failed_count() -> Result<()> {
-        let executor1 = mock_executor("executor-id1".to_string());
-        let executor2 = mock_executor("executor-id2".to_string());
+        let executor = mock_executor("executor-id2".to_string());
         let mut agg_graph = test_aggregation_plan(2).await;
         // Call revive to move the leaf Resolved stages to Running
         agg_graph.revive();
 
         // Complete the first stage
-        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
-            let task_status = mock_completed_task(task, &executor1.id);
-            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
-        }
+        revive_graph_and_complete_next_stage(&mut agg_graph)?;
 
         // 1st task in the second stage
-        let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
-        let task_status1 = mock_completed_task(task1, &executor2.id);
+        let task1 = agg_graph.pop_next_task(&executor.id)?.unwrap();
+        let task_status1 = mock_completed_task(task1, &executor.id);
 
         // 2nd task in the second stage, failed due to IOError
-        let task2 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
+        let task2 = agg_graph.pop_next_task(&executor.id)?.unwrap();
         let task_status2 = mock_failed_task(
             task2.clone(),
             FailedTask {
@@ -2012,7 +1997,7 @@ mod test {
         );
 
         agg_graph.update_task_status(
-            &executor2,
+            &executor,
             vec![task_status1, task_status2],
             4,
             4,
@@ -2023,7 +2008,7 @@ mod test {
         let mut last_attempt = 0;
         // 2nd task's attempts
         for attempt in 1..5 {
-            if let Some(task2_attempt) = 
agg_graph.pop_next_task(&executor2.id)? {
+            if let Some(task2_attempt) = 
agg_graph.pop_next_task(&executor.id)? {
                 assert_eq!(
                     task2_attempt.partition.partition_id,
                     task2.partition.partition_id
@@ -2041,7 +2026,7 @@ mod test {
                         )),
                     },
                 );
-                agg_graph.update_task_status(&executor2, vec![task_status], 4, 
4)?;
+                agg_graph.update_task_status(&executor, vec![task_status], 4, 
4)?;
             }
         }
 
@@ -2075,10 +2060,7 @@ mod test {
         agg_graph.revive();
 
         // Complete the Stage 1
-        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
-            let task_status = mock_completed_task(task, &executor1.id);
-            agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
-        }
+        revive_graph_and_complete_next_stage_with_executor(&mut agg_graph, 
&executor1)?;
 
         // 1st task in the Stage 2
         if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
@@ -2103,13 +2085,10 @@ mod test {
 
         // Two stages were reset, Stage 2 rollback to Unresolved and Stage 1 
move to Running
         assert_eq!(reset.0.len(), 2);
-        assert_eq!(agg_graph.available_tasks(), 1);
+        assert_eq!(agg_graph.available_tasks(), 2);
 
         // Complete the Stage 1 again
-        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
-            let task_status = mock_completed_task(task, &executor1.id);
-            agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
-        }
+        revive_graph_and_complete_next_stage_with_executor(&mut agg_graph, 
&executor1)?;
 
         // Stage 2 move to Running
         agg_graph.revive();
@@ -2148,10 +2127,7 @@ mod test {
         agg_graph.revive();
 
         // Complete the Stage 1
-        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
-            let task_status = mock_completed_task(task, &executor1.id);
-            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
-        }
+        revive_graph_and_complete_next_stage(&mut agg_graph)?;
 
         // 1st task in the Stage 2
         let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
@@ -2198,7 +2174,7 @@ mod test {
         let running_stage = agg_graph.running_stages();
         assert_eq!(running_stage.len(), 1);
         assert_eq!(running_stage[0], 1);
-        assert_eq!(agg_graph.available_tasks(), 1);
+        assert_eq!(agg_graph.available_tasks(), 2);
 
         drain_tasks(&mut agg_graph)?;
         assert!(agg_graph.is_successful(), "Failed to complete agg plan");
@@ -2216,10 +2192,7 @@ mod test {
         assert_eq!(agg_graph.stage_count(), 3);
 
         // Complete the Stage 1
-        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
-            let task_status = mock_completed_task(task, &executor1.id);
-            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
-        }
+        revive_graph_and_complete_next_stage(&mut agg_graph)?;
 
         // Complete the Stage 2, 5 tasks run on executor_2 and 3 tasks run on 
executor_1
         for _i in 0..5 {
@@ -2283,10 +2256,7 @@ mod test {
         agg_graph.revive();
 
         for attempt in 0..6 {
-            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
-                let task_status = mock_completed_task(task, &executor1.id);
-                agg_graph.update_task_status(&executor1, vec![task_status], 4, 
4)?;
-            }
+            revive_graph_and_complete_next_stage(&mut agg_graph)?;
 
             // 1rd task in the Stage 2, failed due to FetchPartitionError
             if let Some(task1) = agg_graph.pop_next_task(&executor2.id)? {
@@ -2318,7 +2288,7 @@ mod test {
                     let running_stage = agg_graph.running_stages();
                     assert_eq!(running_stage.len(), 1);
                     assert_eq!(running_stage[0], 1);
-                    assert_eq!(agg_graph.available_tasks(), 1);
+                    assert_eq!(agg_graph.available_tasks(), 2);
                 } else {
                     // Job is failed after exceeds the max_stage_failures
                     assert_eq!(stage_events.len(), 1);
@@ -2355,10 +2325,7 @@ mod test {
         assert_eq!(agg_graph.stage_count(), 3);
 
         // Complete the Stage 1
-        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
-            let task_status = mock_completed_task(task, &executor1.id);
-            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
-        }
+        revive_graph_and_complete_next_stage(&mut agg_graph)?;
 
         // Complete the Stage 2, 5 tasks run on executor_2, 2 tasks run on 
executor_1, 1 task runs on executor_3
         for _i in 0..5 {
@@ -2559,10 +2526,7 @@ mod test {
         assert_eq!(agg_graph.stage_count(), 3);
 
         // Complete the Stage 1
-        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
-            let task_status = mock_completed_task(task, &executor1.id);
-            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
-        }
+        revive_graph_and_complete_next_stage(&mut agg_graph)?;
 
         // Complete the Stage 2, 5 tasks run on executor_2, 3 tasks run on 
executor_1
         for _i in 0..5 {
@@ -2662,10 +2626,7 @@ mod test {
         assert_eq!(agg_graph.stage_count(), 3);
 
         // Complete the Stage 1
-        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
-            let task_status = mock_completed_task(task, &executor1.id);
-            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
-        }
+        revive_graph_and_complete_next_stage(&mut agg_graph)?;
 
         // Complete the Stage 2, 5 tasks run on executor_2, 3 tasks run on 
executor_1
         for _i in 0..5 {
@@ -2735,7 +2696,7 @@ mod test {
         let running_stage = agg_graph.running_stages();
         assert_eq!(running_stage.len(), 1);
         assert_eq!(running_stage[0], 1);
-        assert_eq!(agg_graph.available_tasks(), 1);
+        assert_eq!(agg_graph.available_tasks(), 2);
 
         // There are two failed stage attempts: Stage 2 and Stage 3
         assert_eq!(agg_graph.failed_stage_attempts.len(), 2);
@@ -2759,14 +2720,9 @@ mod test {
         let executor1 = mock_executor("executor-id1".to_string());
         let executor2 = mock_executor("executor-id2".to_string());
         let mut agg_graph = test_aggregation_plan(4).await;
-        // Call revive to move the leaf Resolved stages to Running
-        agg_graph.revive();
 
         // Complete the Stage 1
-        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
-            let task_status = mock_completed_task(task, &executor1.id);
-            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
-        }
+        revive_graph_and_complete_next_stage(&mut agg_graph)?;
 
         // 1st task in the Stage 2
         let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
diff --git a/ballista/scheduler/src/state/execution_graph_dot.rs 
b/ballista/scheduler/src/state/execution_graph_dot.rs
index 3a6dce7a..d5d7e7ae 100644
--- a/ballista/scheduler/src/state/execution_graph_dot.rs
+++ b/ballista/scheduler/src/state/execution_graph_dot.rs
@@ -432,13 +432,13 @@ mod tests {
         let expected = r#"digraph G {
        subgraph cluster0 {
                label = "Stage 1 [Resolved]";
-               stage_1_0 [shape=box, label="ShuffleWriter [0 partitions]"]
+               stage_1_0 [shape=box, label="ShuffleWriter [2 partitions]"]
                stage_1_0_0 [shape=box, label="MemoryExec"]
                stage_1_0_0 -> stage_1_0
        }
        subgraph cluster1 {
                label = "Stage 2 [Resolved]";
-               stage_2_0 [shape=box, label="ShuffleWriter [0 partitions]"]
+               stage_2_0 [shape=box, label="ShuffleWriter [2 partitions]"]
                stage_2_0_0 [shape=box, label="MemoryExec"]
                stage_2_0_0 -> stage_2_0
        }
@@ -462,7 +462,7 @@ filter_expr="]
        }
        subgraph cluster3 {
                label = "Stage 4 [Resolved]";
-               stage_4_0 [shape=box, label="ShuffleWriter [0 partitions]"]
+               stage_4_0 [shape=box, label="ShuffleWriter [2 partitions]"]
                stage_4_0_0 [shape=box, label="MemoryExec"]
                stage_4_0_0 -> stage_4_0
        }
@@ -531,19 +531,19 @@ filter_expr="]
         let expected = r#"digraph G {
        subgraph cluster0 {
                label = "Stage 1 [Resolved]";
-               stage_1_0 [shape=box, label="ShuffleWriter [0 partitions]"]
+               stage_1_0 [shape=box, label="ShuffleWriter [2 partitions]"]
                stage_1_0_0 [shape=box, label="MemoryExec"]
                stage_1_0_0 -> stage_1_0
        }
        subgraph cluster1 {
                label = "Stage 2 [Resolved]";
-               stage_2_0 [shape=box, label="ShuffleWriter [0 partitions]"]
+               stage_2_0 [shape=box, label="ShuffleWriter [2 partitions]"]
                stage_2_0_0 [shape=box, label="MemoryExec"]
                stage_2_0_0 -> stage_2_0
        }
        subgraph cluster2 {
                label = "Stage 3 [Resolved]";
-               stage_3_0 [shape=box, label="ShuffleWriter [0 partitions]"]
+               stage_3_0 [shape=box, label="ShuffleWriter [2 partitions]"]
                stage_3_0_0 [shape=box, label="MemoryExec"]
                stage_3_0_0 -> stage_3_0
        }
@@ -635,7 +635,7 @@ filter_expr="]
             Field::new("a", DataType::UInt32, false),
             Field::new("b", DataType::UInt32, false),
         ]));
-        let table = Arc::new(MemTable::try_new(schema.clone(), vec![])?);
+        let table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![], 
vec![]])?);
         ctx.register_table("foo", table.clone())?;
         ctx.register_table("bar", table.clone())?;
         ctx.register_table("baz", table)?;
@@ -660,7 +660,8 @@ filter_expr="]
         let ctx = SessionContext::new_with_config(config);
         let schema =
             Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, 
false)]));
-        let table = Arc::new(MemTable::try_new(schema.clone(), vec![])?);
+        // we specify the input partitions to be > 1 because of 
https://github.com/apache/datafusion/issues/12611
+        let table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![], 
vec![]])?);
         ctx.register_table("foo", table.clone())?;
         ctx.register_table("bar", table.clone())?;
         ctx.register_table("baz", table)?;
diff --git a/ballista/scheduler/src/test_utils.rs 
b/ballista/scheduler/src/test_utils.rs
index 59e6a875..5e5dee12 100644
--- a/ballista/scheduler/src/test_utils.rs
+++ b/ballista/scheduler/src/test_utils.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use ballista_core::error::{BallistaError, Result};
+use datafusion::catalog::Session;
 use std::any::Any;
 use std::collections::HashMap;
 use std::future::Future;
@@ -44,19 +45,18 @@ use ballista_core::serde::{protobuf, BallistaCodec};
 use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion::common::DataFusionError;
 use datafusion::datasource::{TableProvider, TableType};
-use datafusion::execution::context::{SessionConfig, SessionContext, 
SessionState};
-use datafusion::functions_aggregate::sum::sum;
-use datafusion::logical_expr::expr::Sort;
-use datafusion::logical_expr::{Expr, LogicalPlan};
+use datafusion::execution::context::{SessionConfig, SessionContext};
+use datafusion::functions_aggregate::{count::count, sum::sum};
+use datafusion::logical_expr::{Expr, LogicalPlan, SortExpr};
 use datafusion::physical_plan::display::DisplayableExecutionPlan;
 use datafusion::physical_plan::ExecutionPlan;
-use datafusion::prelude::{col, count, CsvReadOptions, JoinType};
-use datafusion::test_util::scan_empty;
+use datafusion::prelude::{col, CsvReadOptions, JoinType};
+use datafusion::test_util::scan_empty_with_partitions;
 
 use crate::cluster::BallistaCluster;
 use crate::scheduler_server::event::QueryStageSchedulerEvent;
 
-use crate::state::execution_graph::{ExecutionGraph, TaskDescription};
+use crate::state::execution_graph::{ExecutionGraph, ExecutionStage, 
TaskDescription};
 use ballista_core::utils::default_session_builder;
 use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
 use parking_lot::Mutex;
@@ -89,7 +89,7 @@ impl TableProvider for ExplodingTableProvider {
 
     async fn scan(
         &self,
-        _ctx: &SessionState,
+        _ctx: &dyn Session,
         _projection: Option<&Vec<usize>>,
         _filters: &[Expr],
         _limit: Option<usize>,
@@ -783,6 +783,47 @@ pub fn assert_failed_event(job_id: &str, collector: 
&TestMetricsCollector) {
     assert!(found, "{}", "Expected failed event for job {job_id}");
 }
 
+pub fn revive_graph_and_complete_next_stage(graph: &mut ExecutionGraph) -> 
Result<usize> {
+    let executor = mock_executor("executor-id1".to_string());
+    revive_graph_and_complete_next_stage_with_executor(graph, &executor)
+}
+
+pub fn revive_graph_and_complete_next_stage_with_executor(
+    graph: &mut ExecutionGraph,
+    executor: &ExecutorMetadata,
+) -> Result<usize> {
+    graph.revive();
+
+    // find the num_available_tasks of the next running stage
+    let num_available_tasks = graph
+        .stages()
+        .iter()
+        .map(|(_stage_id, stage)| {
+            if let ExecutionStage::Running(stage) = stage {
+                stage
+                    .task_infos
+                    .iter()
+                    .filter(|info| info.is_none())
+                    .count()
+            } else {
+                0
+            }
+        })
+        .find(|num_available_tasks| num_available_tasks > &0)
+        .unwrap();
+
+    if num_available_tasks > 0 {
+        for _ in 0..num_available_tasks {
+            if let Some(task) = graph.pop_next_task(&executor.id).unwrap() {
+                let task_status = mock_completed_task(task, &executor.id);
+                graph.update_task_status(executor, vec![task_status], 1, 1)?;
+            }
+        }
+    }
+
+    Ok(num_available_tasks)
+}
+
 pub async fn test_aggregation_plan(partition: usize) -> ExecutionGraph {
     test_aggregation_plan_with_job_id(partition, "job").await
 }
@@ -800,7 +841,8 @@ pub async fn test_aggregation_plan_with_job_id(
         Field::new("gmv", DataType::UInt64, false),
     ]);
 
-    let logical_plan = scan_empty(None, &schema, Some(vec![0, 1]))
+    // we specify the input partitions to be > 1 because of 
https://github.com/apache/datafusion/issues/12611
+    let logical_plan = scan_empty_with_partitions(None, &schema, Some(vec![0, 
1]), 2)
         .unwrap()
         .aggregate(vec![col("id")], vec![sum(col("gmv"))])
         .unwrap()
@@ -833,7 +875,8 @@ pub async fn test_two_aggregations_plan(partition: usize) 
-> ExecutionGraph {
         Field::new("gmv", DataType::UInt64, false),
     ]);
 
-    let logical_plan = scan_empty(None, &schema, Some(vec![0, 1, 2]))
+    // we specify the input partitions to be > 1 because of 
https://github.com/apache/datafusion/issues/12611
+    let logical_plan = scan_empty_with_partitions(None, &schema, Some(vec![0, 
1, 2]), 2)
         .unwrap()
         .aggregate(vec![col("id"), col("name")], vec![sum(col("gmv"))])
         .unwrap()
@@ -867,7 +910,8 @@ pub async fn test_coalesce_plan(partition: usize) -> 
ExecutionGraph {
         Field::new("gmv", DataType::UInt64, false),
     ]);
 
-    let logical_plan = scan_empty(None, &schema, Some(vec![0, 1]))
+    // we specify the input partitions to be > 1 because of 
https://github.com/apache/datafusion/issues/12611
+    let logical_plan = scan_empty_with_partitions(None, &schema, Some(vec![0, 
1]), 2)
         .unwrap()
         .limit(0, Some(1))
         .unwrap()
@@ -898,14 +942,15 @@ pub async fn test_join_plan(partition: usize) -> ExecutionGraph {
         Field::new("gmv", DataType::UInt64, false),
     ]);
 
-    let left_plan = scan_empty(Some("left"), &schema, None).unwrap();
+    // we specify the input partitions to be > 1 because of https://github.com/apache/datafusion/issues/12611
+    let left_plan = scan_empty_with_partitions(Some("left"), &schema, None, 2).unwrap();
 
-    let right_plan = scan_empty(Some("right"), &schema, None)
+    let right_plan = scan_empty_with_partitions(Some("right"), &schema, None, 2)
         .unwrap()
         .build()
         .unwrap();
 
-    let sort_expr = Expr::Sort(Sort::new(Box::new(col("id")), false, false));
+    let sort_expr = Expr::Sort(SortExpr::new(Box::new(col("id")), false, false));
 
     let logical_plan = left_plan
         .join(right_plan, JoinType::Inner, (vec!["id"], vec!["id"]), None)
diff --git a/examples/Cargo.toml b/examples/Cargo.toml
index e41e6051..3fd07740 100644
--- a/examples/Cargo.toml
+++ b/examples/Cargo.toml
@@ -38,7 +38,7 @@ ballista = { path = "../ballista/client", version = "0.12.0" }
 datafusion = { workspace = true }
 futures = "0.3"
 num_cpus = "1.13.0"
-prost = "0.12"
+prost = { workspace = true }
 tokio = { version = "1.0", features = [
     "macros",
     "rt",
@@ -46,4 +46,4 @@ tokio = { version = "1.0", features = [
     "sync",
     "parking_lot"
 ] }
-tonic = "0.10"
+tonic = { workspace = true }
diff --git a/python/.cargo/config.toml b/python/.cargo/config.toml
new file mode 100644
index 00000000..d47f983e
--- /dev/null
+++ b/python/.cargo/config.toml
@@ -0,0 +1,11 @@
+[target.x86_64-apple-darwin]
+rustflags = [
+  "-C", "link-arg=-undefined",
+  "-C", "link-arg=dynamic_lookup",
+]
+
+[target.aarch64-apple-darwin]
+rustflags = [
+  "-C", "link-arg=-undefined",
+  "-C", "link-arg=dynamic_lookup",
+]
diff --git a/python/Cargo.toml b/python/Cargo.toml
index 2b2dff41..eb662cb1 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -33,13 +33,12 @@ publish = false
 async-trait = "0.1.77"
 ballista = { path = "../ballista/client", version = "0.12.0" }
 ballista-core = { path = "../ballista/core", version = "0.12.0" }
-datafusion = "35.0.0"
-datafusion-proto = "35.0.0"
+datafusion = "41.0.0"
+datafusion-proto = "41.0.0"
+datafusion-python = "41.0.0"
 
-# we need to use a recent build of ADP that has a public PyDataFrame
-datafusion-python = { git = "https://github.com/apache/arrow-datafusion-python", rev = "5296c0cfcf8e6fcb654d5935252469bf04f929e9" }
-
-pyo3 = { version = "0.20", features = ["extension-module", "abi3", "abi3-py38"] }
+pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38"] }
+pyo3-log = "0.11.0"
 tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread", "sync"] }
 
 [lib]
diff --git a/python/src/context.rs b/python/src/context.rs
index 0d0231c6..be7dd610 100644
--- a/python/src/context.rs
+++ b/python/src/context.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use crate::utils::to_pyerr;
+use datafusion::logical_expr::SortExpr;
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use std::path::PathBuf;
@@ -30,7 +31,7 @@ use datafusion_python::context::{
 };
 use datafusion_python::dataframe::PyDataFrame;
 use datafusion_python::errors::DataFusionError;
-use datafusion_python::expr::PyExpr;
+use datafusion_python::expr::sort_expr::PySortExpr;
 use datafusion_python::sql::logical::PyLogicalPlan;
 use datafusion_python::utils::wait_for_future;
 
@@ -187,7 +188,7 @@ impl PySessionContext {
         file_extension: &str,
         skip_metadata: bool,
         schema: Option<PyArrowType<Schema>>,
-        file_sort_order: Option<Vec<Vec<PyExpr>>>,
+        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
         py: Python,
     ) -> PyResult<PyDataFrame> {
         let mut options = ParquetReadOptions::default()
@@ -199,7 +200,14 @@ impl PySessionContext {
         options.file_sort_order = file_sort_order
             .unwrap_or_default()
             .into_iter()
-            .map(|e| e.into_iter().map(|f| f.into()).collect())
+            .map(|e| {
+                e.into_iter()
+                    .map(|f| {
+                        let sort_expr: SortExpr = f.into();
+                        *sort_expr.expr
+                    })
+                    .collect()
+            })
             .collect();
 
         let result = self.ctx.read_parquet(path, options);
@@ -299,7 +307,7 @@ impl PySessionContext {
         file_extension: &str,
         skip_metadata: bool,
         schema: Option<PyArrowType<Schema>>,
-        file_sort_order: Option<Vec<Vec<PyExpr>>>,
+        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
         py: Python,
     ) -> PyResult<()> {
         let mut options = ParquetReadOptions::default()
@@ -311,7 +319,14 @@ impl PySessionContext {
         options.file_sort_order = file_sort_order
             .unwrap_or_default()
             .into_iter()
-            .map(|e| e.into_iter().map(|f| f.into()).collect())
+            .map(|e| {
+                e.into_iter()
+                    .map(|f| {
+                        let sort_expr: SortExpr = f.into();
+                        *sort_expr.expr
+                    })
+                    .collect()
+            })
             .collect();
 
         let result = self.ctx.register_parquet(name, path, options);
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 04cf232a..5fbd2491 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -22,7 +22,8 @@ mod utils;
 pub use crate::context::PySessionContext;
 
 #[pymodule]
-fn pyballista_internal(_py: Python, m: &PyModule) -> PyResult<()> {
+fn pyballista_internal(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
+    pyo3_log::init();
     // Ballista structs
     m.add_class::<PySessionContext>()?;
     // DataFusion structs


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to