This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git
The following commit(s) were added to refs/heads/main by this push:
new 2f223db2 Upgrade to Datafusion 41 (#1062)
2f223db2 is described below
commit 2f223db21557c15080bf865ac692d276b8f0b770
Author: Baris Palaska <[email protected]>
AuthorDate: Fri Sep 27 19:16:15 2024 +0100
Upgrade to Datafusion 41 (#1062)
* upgrade dependencies, some fixes, still wip
compiles
add license
Update datafusion protobuf definitions (#1057)
* update datafusion proto defs
* allow optionals in proto3
update docker environment for higher protoc version
* runs
* test e2e, fix python
* rm unnecessary dependency
* create BallistaLogicalExtensionCodec that can decode/encode file formats,
fix some tests
* fix tests
* clippy, tomlfmt
* fix grpc connect info extract
* extract into method, remove unnecessary log
* datafusion to 41, adjust other deps
---
Cargo.toml | 26 +-
ballista-cli/Cargo.toml | 4 +-
ballista-cli/src/main.rs | 28 +-
ballista/client/README.md | 3 +-
ballista/client/src/context.rs | 32 +--
ballista/core/Cargo.toml | 6 +-
ballista/core/src/config.rs | 14 +-
.../core/src/execution_plans/distributed_query.rs | 4 +
.../core/src/execution_plans/shuffle_reader.rs | 4 +
.../core/src/execution_plans/shuffle_writer.rs | 4 +
.../core/src/execution_plans/unresolved_shuffle.rs | 4 +
ballista/core/src/serde/mod.rs | 104 +++++++-
ballista/core/src/serde/scheduler/mod.rs | 5 +
ballista/core/src/utils.rs | 27 +-
ballista/executor/Cargo.toml | 1 -
ballista/executor/src/collect.rs | 4 +
ballista/executor/src/executor.rs | 22 +-
ballista/scheduler/Cargo.toml | 17 +-
ballista/scheduler/src/api/handlers.rs | 297 ++++++++++++---------
ballista/scheduler/src/api/mod.rs | 153 +++--------
ballista/scheduler/src/bin/main.rs | 13 +-
ballista/scheduler/src/cluster/mod.rs | 26 +-
ballista/scheduler/src/config.rs | 6 +-
ballista/scheduler/src/planner.rs | 2 +
ballista/scheduler/src/scheduler_process.rs | 85 ++----
ballista/scheduler/src/scheduler_server/grpc.rs | 15 +-
ballista/scheduler/src/scheduler_server/mod.rs | 6 +-
ballista/scheduler/src/state/execution_graph.rs | 140 ++++------
.../scheduler/src/state/execution_graph_dot.rs | 17 +-
ballista/scheduler/src/test_utils.rs | 73 ++++-
examples/Cargo.toml | 4 +-
python/.cargo/config.toml | 11 +
python/Cargo.toml | 11 +-
python/src/context.rs | 25 +-
python/src/lib.rs | 3 +-
35 files changed, 653 insertions(+), 543 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
index 3c451885..ea1c8321 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,19 +21,23 @@ members = ["ballista-cli", "ballista/cache",
"ballista/client", "ballista/core",
resolver = "2"
[workspace.dependencies]
-arrow = { version = "52.0.0", features = ["ipc_compression"] }
-arrow-flight = { version = "52.0.0", features = ["flight-sql-experimental"] }
-arrow-schema = { version = "52.0.0", default-features = false }
+arrow = { version = "52.2.0", features = ["ipc_compression"] }
+arrow-flight = { version = "52.2.0", features = ["flight-sql-experimental"] }
+arrow-schema = { version = "52.2.0", default-features = false }
+clap = { version = "3", features = ["derive", "cargo"] }
configure_me = { version = "0.4.0" }
configure_me_codegen = { version = "0.4.4" }
-datafusion = "39.0.0"
-datafusion-cli = "39.0.0"
-datafusion-proto = "39.0.0"
-datafusion-proto-common = "39.0.0"
-object_store = "0.10.1"
-sqlparser = "0.47.0"
-tonic = { version = "0.11" }
-tonic-build = { version = "0.11", default-features = false, features = [
+# NOTE: when upgrading past v41, bump directly to datafusion v43 — v42 has a serde bug
(https://github.com/apache/datafusion/pull/12626)
+datafusion = "41.0.0"
+datafusion-cli = "41.0.0"
+datafusion-proto = "41.0.0"
+datafusion-proto-common = "41.0.0"
+object_store = "0.10.2"
+prost = "0.12.0"
+prost-types = "0.12.0"
+sqlparser = "0.49.0"
+tonic = { version = "0.11.0" }
+tonic-build = { version = "0.11.0", default-features = false, features = [
"transport",
"prost"
] }
diff --git a/ballista-cli/Cargo.toml b/ballista-cli/Cargo.toml
index e07ad279..dc8ff7cb 100644
--- a/ballista-cli/Cargo.toml
+++ b/ballista-cli/Cargo.toml
@@ -30,14 +30,14 @@ readme = "README.md"
[dependencies]
ballista = { path = "../ballista/client", version = "0.12.0", features =
["standalone"] }
-clap = { version = "3", features = ["derive", "cargo"] }
+clap = { workspace = true }
datafusion = { workspace = true }
datafusion-cli = { workspace = true }
dirs = "5.0.1"
env_logger = "0.10"
mimalloc = { version = "0.1", default-features = false }
num_cpus = "1.13.0"
-rustyline = "11.0"
+rustyline = "11.0.0"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread",
"sync", "parking_lot"] }
[features]
diff --git a/ballista-cli/src/main.rs b/ballista-cli/src/main.rs
index a8055c87..6aeecd6c 100644
--- a/ballista-cli/src/main.rs
+++ b/ballista-cli/src/main.rs
@@ -36,7 +36,7 @@ struct Args {
short = 'p',
long,
help = "Path to your data, default to current directory",
- validator(is_valid_data_dir)
+ value_parser(parse_valid_data_dir)
)]
data_path: Option<String>,
@@ -44,14 +44,14 @@ struct Args {
short = 'c',
long,
help = "The batch size of each query, or use Ballista default",
- validator(is_valid_batch_size)
+ value_parser(parse_batch_size)
)]
batch_size: Option<usize>,
#[clap(
long,
help = "The max concurrent tasks, only for Ballista local mode.
Default: all available cores",
- validator(is_valid_concurrent_tasks_size)
+ value_parser(parse_valid_concurrent_tasks_size)
)]
concurrent_tasks: Option<usize>,
@@ -60,7 +60,7 @@ struct Args {
long,
multiple_values = true,
help = "Execute commands from file(s), then exit",
- validator(is_valid_file)
+ value_parser(parse_valid_file)
)]
file: Vec<String>,
@@ -69,12 +69,12 @@ struct Args {
long,
multiple_values = true,
help = "Run the provided files on startup instead of ~/.ballistarc",
- validator(is_valid_file),
+ value_parser(parse_valid_file),
conflicts_with = "file"
)]
rc: Option<Vec<String>>,
- #[clap(long, arg_enum, default_value_t = PrintFormat::Table)]
+ #[clap(long, value_enum, default_value_t = PrintFormat::Table)]
format: PrintFormat,
#[clap(long, help = "Ballista scheduler host")]
@@ -168,32 +168,32 @@ pub async fn main() -> Result<()> {
Ok(())
}
-fn is_valid_file(dir: &str) -> std::result::Result<(), String> {
+fn parse_valid_file(dir: &str) -> std::result::Result<String, String> {
if Path::new(dir).is_file() {
- Ok(())
+ Ok(dir.to_string())
} else {
Err(format!("Invalid file '{dir}'"))
}
}
-fn is_valid_data_dir(dir: &str) -> std::result::Result<(), String> {
+fn parse_valid_data_dir(dir: &str) -> std::result::Result<String, String> {
if Path::new(dir).is_dir() {
- Ok(())
+ Ok(dir.to_string())
} else {
Err(format!("Invalid data directory '{dir}'"))
}
}
-fn is_valid_batch_size(size: &str) -> std::result::Result<(), String> {
+fn parse_batch_size(size: &str) -> std::result::Result<usize, String> {
match size.parse::<usize>() {
- Ok(size) if size > 0 => Ok(()),
+ Ok(size) if size > 0 => Ok(size),
_ => Err(format!("Invalid batch size '{size}'")),
}
}
-fn is_valid_concurrent_tasks_size(size: &str) -> std::result::Result<(),
String> {
+fn parse_valid_concurrent_tasks_size(size: &str) -> std::result::Result<usize,
String> {
match size.parse::<usize>() {
- Ok(size) if size > 0 => Ok(()),
+ Ok(size) if size > 0 => Ok(size),
_ => Err(format!("Invalid concurrent_tasks size '{size}'")),
}
}
diff --git a/ballista/client/README.md b/ballista/client/README.md
index 19dc1439..ac65bc98 100644
--- a/ballista/client/README.md
+++ b/ballista/client/README.md
@@ -92,7 +92,8 @@ data set. Download the file and add it to the `testdata`
folder before running t
```rust,no_run
use ballista::prelude::*;
-use datafusion::prelude::{col, min, max, avg, sum, ParquetReadOptions};
+use datafusion::prelude::{col, ParquetReadOptions};
+use datafusion::functions_aggregate::{min_max::min, min_max::max, sum::sum,
average::avg};
#[tokio::main]
async fn main() -> Result<()> {
diff --git a/ballista/client/src/context.rs b/ballista/client/src/context.rs
index de22b777..269afc64 100644
--- a/ballista/client/src/context.rs
+++ b/ballista/client/src/context.rs
@@ -19,6 +19,7 @@
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::execution::context::DataFilePaths;
+use datafusion::sql::TableReference;
use log::info;
use parking_lot::Mutex;
use sqlparser::ast::Statement;
@@ -33,7 +34,6 @@ use ballista_core::utils::{
};
use datafusion_proto::protobuf::LogicalPlanNode;
-use datafusion::catalog::TableReference;
use datafusion::dataframe::DataFrame;
use datafusion::datasource::{source_as_provider, TableProvider};
use datafusion::error::{DataFusionError, Result};
@@ -791,7 +791,7 @@ mod standalone_tests {
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------+",
- "| MIN(test.id) |",
+ "| min(test.id) |",
"+--------------+",
"| 0 |",
"+--------------+",
@@ -802,7 +802,7 @@ mod standalone_tests {
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------+",
- "| MAX(test.id) |",
+ "| max(test.id) |",
"+--------------+",
"| 7 |",
"+--------------+",
@@ -818,7 +818,7 @@ mod standalone_tests {
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------+",
- "| SUM(test.id) |",
+ "| sum(test.id) |",
"+--------------+",
"| 28 |",
"+--------------+",
@@ -833,7 +833,7 @@ mod standalone_tests {
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------+",
- "| AVG(test.id) |",
+ "| avg(test.id) |",
"+--------------+",
"| 3.5 |",
"+--------------+",
@@ -849,7 +849,7 @@ mod standalone_tests {
let res = df.collect().await.unwrap();
let expected = vec![
"+----------------+",
- "| COUNT(test.id) |",
+ "| count(test.id) |",
"+----------------+",
"| 8 |",
"+----------------+",
@@ -867,7 +867,7 @@ mod standalone_tests {
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------------+",
- "| APPROX_DISTINCT(test.id) |",
+ "| approx_distinct(test.id) |",
"+--------------------------+",
"| 8 |",
"+--------------------------+",
@@ -885,7 +885,7 @@ mod standalone_tests {
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------------+",
- "| ARRAY_AGG(test.id) |",
+ "| array_agg(test.id) |",
"+--------------------------+",
"| [4, 5, 6, 7, 2, 3, 0, 1] |",
"+--------------------------+",
@@ -914,7 +914,7 @@ mod standalone_tests {
let res = df.collect().await.unwrap();
let expected = vec![
"+-------------------+",
- "| VAR_POP(test.id) |",
+ "| var_pop(test.id) |",
"+-------------------+",
"| 5.250000000000001 |",
"+-------------------+",
@@ -946,7 +946,7 @@ mod standalone_tests {
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------+",
- "| STDDEV(test.id) |",
+ "| stddev(test.id) |",
"+--------------------+",
"| 2.4494897427831783 |",
"+--------------------+",
@@ -960,7 +960,7 @@ mod standalone_tests {
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------+",
- "| STDDEV(test.id) |",
+ "| stddev(test.id) |",
"+--------------------+",
"| 2.4494897427831783 |",
"+--------------------+",
@@ -996,25 +996,27 @@ mod standalone_tests {
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------------------+",
- "| CORR(test.id,test.tinyint_col) |",
+ "| corr(test.id,test.tinyint_col) |",
"+--------------------------------+",
"| 0.21821789023599245 |",
"+--------------------------------+",
];
assert_result_eq(expected, &res);
}
+ // enable when upgrading Datafusion to > 42
+ #[ignore]
#[tokio::test]
async fn test_aggregate_approx_percentile() {
let context = create_test_context().await;
let df = context
- .sql("select approx_percentile_cont_with_weight(\"id\", 2, 0.5)
from test")
+ .sql("select approx_percentile_cont_with_weight(id, 2, 0.5) from
test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+-------------------------------------------------------------------+",
- "|
APPROX_PERCENTILE_CONT_WITH_WEIGHT(test.id,Int64(2),Float64(0.5)) |",
+ "|
approx_percentile_cont_with_weight(test.id,Int64(2),Float64(0.5)) |",
"+-------------------------------------------------------------------+",
"| 1
|",
"+-------------------------------------------------------------------+",
@@ -1028,7 +1030,7 @@ mod standalone_tests {
let res = df.collect().await.unwrap();
let expected = vec![
"+------------------------------------------------------+",
- "| APPROX_PERCENTILE_CONT(test.double_col,Float64(0.5)) |",
+ "| approx_percentile_cont(test.double_col,Float64(0.5)) |",
"+------------------------------------------------------+",
"| 7.574999999999999 |",
"+------------------------------------------------------+",
diff --git a/ballista/core/Cargo.toml b/ballista/core/Cargo.toml
index fccdd0ec..8a01f56f 100644
--- a/ballista/core/Cargo.toml
+++ b/ballista/core/Cargo.toml
@@ -51,7 +51,7 @@ async-trait = "0.1.41"
ballista-cache = { path = "../cache", version = "0.12.0" }
bytes = "1.0"
chrono = { version = "0.4", default-features = false }
-clap = { version = "3", features = ["derive", "cargo"] }
+clap = { workspace = true }
datafusion = { workspace = true }
datafusion-objectstore-hdfs = { version = "0.1.4", default-features = false,
optional = true }
datafusion-proto = { workspace = true }
@@ -68,8 +68,8 @@ once_cell = "1.9.0"
parking_lot = "0.12"
parse_arg = "0.1.3"
-prost = "0.12"
-prost-types = "0.12"
+prost = { workspace = true }
+prost-types = { workspace = true }
rand = "0.8"
serde = { version = "1", features = ["derive"] }
sqlparser = { workspace = true }
diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs
index 03c8f6b9..46424ecf 100644
--- a/ballista/core/src/config.rs
+++ b/ballista/core/src/config.rs
@@ -18,7 +18,7 @@
//! Ballista configuration
-use clap::ArgEnum;
+use clap::ValueEnum;
use core::fmt;
use std::collections::HashMap;
use std::result;
@@ -307,7 +307,7 @@ impl BallistaConfig {
// an enum used to configure the scheduler policy
// needs to be visible to code generated by configure_me
-#[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)]
+#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)]
pub enum TaskSchedulingPolicy {
PullStaged,
PushStaged,
@@ -317,7 +317,7 @@ impl std::str::FromStr for TaskSchedulingPolicy {
type Err = String;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
- ArgEnum::from_str(s, true)
+ ValueEnum::from_str(s, true)
}
}
@@ -329,7 +329,7 @@ impl parse_arg::ParseArgFromStr for TaskSchedulingPolicy {
// an enum used to configure the log rolling policy
// needs to be visible to code generated by configure_me
-#[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)]
+#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)]
pub enum LogRotationPolicy {
Minutely,
Hourly,
@@ -341,7 +341,7 @@ impl std::str::FromStr for LogRotationPolicy {
type Err = String;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
- ArgEnum::from_str(s, true)
+ ValueEnum::from_str(s, true)
}
}
@@ -353,7 +353,7 @@ impl parse_arg::ParseArgFromStr for LogRotationPolicy {
// an enum used to configure the source data cache policy
// needs to be visible to code generated by configure_me
-#[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)]
+#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)]
pub enum DataCachePolicy {
LocalDiskFile,
}
@@ -362,7 +362,7 @@ impl std::str::FromStr for DataCachePolicy {
type Err = String;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
- ArgEnum::from_str(s, true)
+ ValueEnum::from_str(s, true)
}
}
diff --git a/ballista/core/src/execution_plans/distributed_query.rs
b/ballista/core/src/execution_plans/distributed_query.rs
index b96367bb..050ba877 100644
--- a/ballista/core/src/execution_plans/distributed_query.rs
+++ b/ballista/core/src/execution_plans/distributed_query.rs
@@ -154,6 +154,10 @@ impl<T: 'static + AsLogicalPlan> DisplayAs for
DistributedQueryExec<T> {
}
impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
+ fn name(&self) -> &str {
+ "DistributedQueryExec"
+ }
+
fn as_any(&self) -> &dyn Any {
self
}
diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs
b/ballista/core/src/execution_plans/shuffle_reader.rs
index 79dfe296..2f856b39 100644
--- a/ballista/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/core/src/execution_plans/shuffle_reader.rs
@@ -107,6 +107,10 @@ impl DisplayAs for ShuffleReaderExec {
}
impl ExecutionPlan for ShuffleReaderExec {
+ fn name(&self) -> &str {
+ "ShuffleReaderExec"
+ }
+
fn as_any(&self) -> &dyn Any {
self
}
diff --git a/ballista/core/src/execution_plans/shuffle_writer.rs
b/ballista/core/src/execution_plans/shuffle_writer.rs
index 7f21b18b..87e4feea 100644
--- a/ballista/core/src/execution_plans/shuffle_writer.rs
+++ b/ballista/core/src/execution_plans/shuffle_writer.rs
@@ -355,6 +355,10 @@ impl DisplayAs for ShuffleWriterExec {
}
impl ExecutionPlan for ShuffleWriterExec {
+ fn name(&self) -> &str {
+ "ShuffleWriterExec"
+ }
+
fn as_any(&self) -> &dyn Any {
self
}
diff --git a/ballista/core/src/execution_plans/unresolved_shuffle.rs
b/ballista/core/src/execution_plans/unresolved_shuffle.rs
index b3c30c0d..e227e2ac 100644
--- a/ballista/core/src/execution_plans/unresolved_shuffle.rs
+++ b/ballista/core/src/execution_plans/unresolved_shuffle.rs
@@ -82,6 +82,10 @@ impl DisplayAs for UnresolvedShuffleExec {
}
impl ExecutionPlan for UnresolvedShuffleExec {
+ fn name(&self) -> &str {
+ "UnresolvedShuffleExec"
+ }
+
fn as_any(&self) -> &dyn Any {
self
}
diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs
index 08208eed..2bb555d1 100644
--- a/ballista/core/src/serde/mod.rs
+++ b/ballista/core/src/serde/mod.rs
@@ -21,9 +21,13 @@
use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction};
use arrow_flight::sql::ProstMessageExt;
-use datafusion::common::DataFusionError;
+use datafusion::common::{DataFusionError, Result};
use datafusion::execution::FunctionRegistry;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};
+use datafusion_proto::logical_plan::file_formats::{
+ ArrowLogicalExtensionCodec, AvroLogicalExtensionCodec,
CsvLogicalExtensionCodec,
+ JsonLogicalExtensionCodec, ParquetLogicalExtensionCodec,
+};
use
datafusion_proto::physical_plan::from_proto::parse_protobuf_hash_partitioning;
use datafusion_proto::protobuf::proto_error;
use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
@@ -84,7 +88,7 @@ pub struct BallistaCodec<
impl Default for BallistaCodec {
fn default() -> Self {
Self {
- logical_extension_codec: Arc::new(DefaultLogicalExtensionCodec {}),
+ logical_extension_codec:
Arc::new(BallistaLogicalExtensionCodec::default()),
physical_extension_codec: Arc::new(BallistaPhysicalExtensionCodec
{}),
logical_plan_repr: PhantomData,
physical_plan_repr: PhantomData,
@@ -114,6 +118,102 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> BallistaCodec<T,
}
}
+#[derive(Debug)]
+pub struct BallistaLogicalExtensionCodec {
+ default_codec: Arc<dyn LogicalExtensionCodec>,
+ file_format_codecs: Vec<Arc<dyn LogicalExtensionCodec>>,
+}
+
+impl BallistaLogicalExtensionCodec {
+ fn try_any<T>(
+ &self,
+ mut f: impl FnMut(&dyn LogicalExtensionCodec) -> Result<T>,
+ ) -> Result<T> {
+ let mut last_err = None;
+ for codec in &self.file_format_codecs {
+ match f(codec.as_ref()) {
+ Ok(node) => return Ok(node),
+ Err(err) => last_err = Some(err),
+ }
+ }
+
+ Err(last_err.unwrap_or_else(|| {
+ DataFusionError::NotImplemented("Empty list of composed
codecs".to_owned())
+ }))
+ }
+}
+
+impl Default for BallistaLogicalExtensionCodec {
+ fn default() -> Self {
+ Self {
+ default_codec: Arc::new(DefaultLogicalExtensionCodec {}),
+ file_format_codecs: vec![
+ Arc::new(CsvLogicalExtensionCodec {}),
+ Arc::new(JsonLogicalExtensionCodec {}),
+ Arc::new(ParquetLogicalExtensionCodec {}),
+ Arc::new(ArrowLogicalExtensionCodec {}),
+ Arc::new(AvroLogicalExtensionCodec {}),
+ ],
+ }
+ }
+}
+
+impl LogicalExtensionCodec for BallistaLogicalExtensionCodec {
+ fn try_decode(
+ &self,
+ buf: &[u8],
+ inputs: &[datafusion::logical_expr::LogicalPlan],
+ ctx: &datafusion::prelude::SessionContext,
+ ) -> Result<datafusion::logical_expr::Extension> {
+ self.default_codec.try_decode(buf, inputs, ctx)
+ }
+
+ fn try_encode(
+ &self,
+ node: &datafusion::logical_expr::Extension,
+ buf: &mut Vec<u8>,
+ ) -> Result<()> {
+ self.default_codec.try_encode(node, buf)
+ }
+
+ fn try_decode_table_provider(
+ &self,
+ buf: &[u8],
+ table_ref: &datafusion::sql::TableReference,
+ schema: datafusion::arrow::datatypes::SchemaRef,
+ ctx: &datafusion::prelude::SessionContext,
+ ) -> Result<Arc<dyn datafusion::catalog::TableProvider>> {
+ self.default_codec
+ .try_decode_table_provider(buf, table_ref, schema, ctx)
+ }
+
+ fn try_encode_table_provider(
+ &self,
+ table_ref: &datafusion::sql::TableReference,
+ node: Arc<dyn datafusion::catalog::TableProvider>,
+ buf: &mut Vec<u8>,
+ ) -> Result<()> {
+ self.default_codec
+ .try_encode_table_provider(table_ref, node, buf)
+ }
+
+ fn try_decode_file_format(
+ &self,
+ buf: &[u8],
+ ctx: &datafusion::prelude::SessionContext,
+ ) -> Result<Arc<dyn
datafusion::datasource::file_format::FileFormatFactory>> {
+ self.try_any(|codec| codec.try_decode_file_format(buf, ctx))
+ }
+
+ fn try_encode_file_format(
+ &self,
+ buf: &mut Vec<u8>,
+ node: Arc<dyn datafusion::datasource::file_format::FileFormatFactory>,
+ ) -> Result<()> {
+ self.try_any(|codec| codec.try_encode_file_format(buf, node.clone()))
+ }
+}
+
#[derive(Debug)]
pub struct BallistaPhysicalExtensionCodec {}
diff --git a/ballista/core/src/serde/scheduler/mod.rs
b/ballista/core/src/serde/scheduler/mod.rs
index 0ced200e..23c9c425 100644
--- a/ballista/core/src/serde/scheduler/mod.rs
+++ b/ballista/core/src/serde/scheduler/mod.rs
@@ -25,6 +25,7 @@ use datafusion::arrow::array::{
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::DataFusionError;
use datafusion::execution::FunctionRegistry;
+use datafusion::logical_expr::planner::ExprPlanner;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::Partitioning;
@@ -299,6 +300,10 @@ pub struct SimpleFunctionRegistry {
}
impl FunctionRegistry for SimpleFunctionRegistry {
+ fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
+ vec![]
+ }
+
fn udfs(&self) -> HashSet<String> {
self.scalar_functions.keys().cloned().collect()
}
diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs
index 45a4f53f..7e88ffaf 100644
--- a/ballista/core/src/utils.rs
+++ b/ballista/core/src/utils.rs
@@ -35,6 +35,7 @@ use datafusion::execution::context::{
QueryPlanner, SessionConfig, SessionContext, SessionState,
};
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
+use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::logical_expr::{DdlStatement, LogicalPlan};
use datafusion::physical_plan::aggregates::AggregateExec;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
@@ -62,13 +63,14 @@ use tonic::transport::{Channel, Error, Server};
/// Default session builder using the provided configuration
pub fn default_session_builder(config: SessionConfig) -> SessionState {
- SessionState::new_with_config_rt(
- config,
- Arc::new(
+ SessionStateBuilder::new()
+ .with_default_features()
+ .with_config(config)
+ .with_runtime_env(Arc::new(
RuntimeEnv::new(with_object_store_registry(RuntimeConfig::default()))
.unwrap(),
- ),
- )
+ ))
+ .build()
}
/// Stream data to disk in Arrow IPC format
@@ -252,15 +254,16 @@ pub fn create_df_ctx_with_ballista_query_planner<T:
'static + AsLogicalPlan>(
let session_config = SessionConfig::new()
.with_target_partitions(config.default_shuffle_partitions())
.with_information_schema(true);
- let mut session_state = SessionState::new_with_config_rt(
- session_config,
- Arc::new(
+ let session_state = SessionStateBuilder::new()
+ .with_default_features()
+ .with_config(session_config)
+ .with_runtime_env(Arc::new(
RuntimeEnv::new(with_object_store_registry(RuntimeConfig::default()))
.unwrap(),
- ),
- )
- .with_query_planner(planner);
- session_state = session_state.with_session_id(session_id);
+ ))
+ .with_query_planner(planner)
+ .with_session_id(session_id)
+ .build();
// the SessionContext created here is the client side context, but the
session_id is from server side.
SessionContext::new_with_state(session_state)
}
diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml
index 6bebaa87..e0ca6efb 100644
--- a/ballista/executor/Cargo.toml
+++ b/ballista/executor/Cargo.toml
@@ -48,7 +48,6 @@ dashmap = "5.4.0"
datafusion = { workspace = true }
datafusion-proto = { workspace = true }
futures = "0.3"
-hyper = "0.14.4"
log = "0.4"
mimalloc = { version = "0.1", default-features = false, optional = true }
num_cpus = "1.13.0"
diff --git a/ballista/executor/src/collect.rs b/ballista/executor/src/collect.rs
index eb96e314..1d77e719 100644
--- a/ballista/executor/src/collect.rs
+++ b/ballista/executor/src/collect.rs
@@ -67,6 +67,10 @@ impl DisplayAs for CollectExec {
}
impl ExecutionPlan for CollectExec {
+ fn name(&self) -> &str {
+ "CollectExec"
+ }
+
fn as_any(&self) -> &dyn Any {
self
}
diff --git a/ballista/executor/src/executor.rs
b/ballista/executor/src/executor.rs
index ccc7f273..4e83b125 100644
--- a/ballista/executor/src/executor.rs
+++ b/ballista/executor/src/executor.rs
@@ -28,6 +28,8 @@ use ballista_core::serde::scheduler::PartitionId;
use dashmap::DashMap;
use datafusion::execution::context::TaskContext;
use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::functions::all_default_functions;
+use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
use futures::future::AbortHandle;
use std::collections::HashMap;
@@ -103,12 +105,22 @@ impl Executor {
concurrent_tasks: usize,
execution_engine: Option<Arc<dyn ExecutionEngine>>,
) -> Self {
+ let scalar_functions = all_default_functions()
+ .into_iter()
+ .map(|f| (f.name().to_string(), f))
+ .collect();
+
+ let aggregate_functions = all_default_aggregate_functions()
+ .into_iter()
+ .map(|f| (f.name().to_string(), f))
+ .collect();
+
Self {
metadata,
work_dir: work_dir.to_owned(),
- // TODO add logic to dynamically load UDF/UDAFs libs from files
- scalar_functions: HashMap::new(),
- aggregate_functions: HashMap::new(),
+ scalar_functions,
+ aggregate_functions,
+ // TODO: set to default window functions when they are moved to
udwf
window_functions: HashMap::new(),
runtime,
runtime_with_data_cache,
@@ -277,6 +289,10 @@ mod test {
}
impl ExecutionPlan for NeverendingOperator {
+ fn name(&self) -> &str {
+ "NeverendingOperator"
+ }
+
fn as_any(&self) -> &dyn Any {
self
}
diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml
index dd878b9a..596db6d1 100644
--- a/ballista/scheduler/Cargo.toml
+++ b/ballista/scheduler/Cargo.toml
@@ -46,20 +46,19 @@ anyhow = "1"
arrow-flight = { workspace = true }
async-recursion = "1.0.0"
async-trait = "0.1.41"
+axum = "0.6.20"
ballista-core = { path = "../core", version = "0.12.0", features = ["s3"] }
base64 = { version = "0.21" }
-clap = { version = "3", features = ["derive", "cargo"] }
+clap = { workspace = true }
configure_me = { workspace = true }
dashmap = "5.4.0"
datafusion = { workspace = true }
datafusion-proto = { workspace = true }
-etcd-client = { version = "0.12", optional = true }
+etcd-client = { version = "0.14", optional = true }
flatbuffers = { version = "23.5.26" }
futures = "0.3"
graphviz-rust = "0.8.0"
-http = "0.2"
-http-body = "0.4"
-hyper = "0.14.4"
+http = "0.2.9"
itertools = "0.12.0"
log = "0.4"
object_store = { workspace = true }
@@ -67,20 +66,20 @@ once_cell = { version = "1.16.0", optional = true }
parking_lot = "0.12"
parse_arg = "0.1.3"
prometheus = { version = "0.13", features = ["process"], optional = true }
-prost = "0.12"
-prost-types = { version = "0.12.0" }
+prost = { workspace = true }
+prost-types = { workspace = true }
rand = "0.8"
serde = { version = "1", features = ["derive"] }
sled_package = { package = "sled", version = "0.34", optional = true }
tokio = { version = "1.0", features = ["full"] }
tokio-stream = { version = "0.1", features = ["net"], optional = true }
tonic = { workspace = true }
-tower = { version = "0.4" }
+# pin tower 0.4.7, the version the workspace tonic release depends on
+tower = { version = "0.4.7", default-features = false, features = ["make",
"util"] }
tracing = { workspace = true }
tracing-appender = { workspace = true }
tracing-subscriber = { workspace = true }
uuid = { version = "1.0", features = ["v4"] }
-warp = "0.3"
[dev-dependencies]
ballista-core = { path = "../core", version = "0.12.0" }
diff --git a/ballista/scheduler/src/api/handlers.rs
b/ballista/scheduler/src/api/handlers.rs
index 463ca217..4d0366ff 100644
--- a/ballista/scheduler/src/api/handlers.rs
+++ b/ballista/scheduler/src/api/handlers.rs
@@ -14,6 +14,11 @@ use crate::scheduler_server::event::QueryStageSchedulerEvent;
use crate::scheduler_server::SchedulerServer;
use crate::state::execution_graph::ExecutionStage;
use crate::state::execution_graph_dot::ExecutionGraphDot;
+use axum::{
+ extract::{Path, State},
+ response::{IntoResponse, Response},
+ Json,
+};
use ballista_core::serde::protobuf::job_status::Status;
use ballista_core::BALLISTA_VERSION;
use datafusion::physical_plan::metrics::{MetricValue, MetricsSet, Time};
@@ -22,10 +27,9 @@ use datafusion_proto::physical_plan::AsExecutionPlan;
use graphviz_rust::cmd::{CommandArg, Format};
use graphviz_rust::exec;
use graphviz_rust::printer::PrinterContext;
-use http::header::CONTENT_TYPE;
-
+use http::{header::CONTENT_TYPE, StatusCode};
+use std::sync::Arc;
use std::time::Duration;
-use warp::Rejection;
#[derive(Debug, serde::Serialize)]
struct SchedulerStateResponse {
@@ -64,22 +68,26 @@ pub struct QueryStageSummary {
pub elapsed_compute: String,
}
-/// Return current scheduler state
-pub(crate) async fn get_scheduler_state<T: AsLogicalPlan, U: AsExecutionPlan>(
- data_server: SchedulerServer<T, U>,
-) -> Result<impl warp::Reply, Rejection> {
+pub async fn get_scheduler_state<
+ T: AsLogicalPlan + Clone + Send + Sync + 'static,
+ U: AsExecutionPlan + Send + Sync + 'static,
+>(
+ State(data_server): State<Arc<SchedulerServer<T, U>>>,
+) -> impl IntoResponse {
let response = SchedulerStateResponse {
started: data_server.start_time,
version: BALLISTA_VERSION,
};
- Ok(warp::reply::json(&response))
+ Json(response)
}
-/// Return list of executors
-pub(crate) async fn get_executors<T: AsLogicalPlan, U: AsExecutionPlan>(
- data_server: SchedulerServer<T, U>,
-) -> Result<impl warp::Reply, Rejection> {
- let state = data_server.state;
+pub async fn get_executors<
+ T: AsLogicalPlan + Clone + Send + Sync + 'static,
+ U: AsExecutionPlan + Send + Sync + 'static,
+>(
+ State(data_server): State<Arc<SchedulerServer<T, U>>>,
+) -> impl IntoResponse {
+ let state = &data_server.state;
let executors: Vec<ExecutorMetaResponse> = state
.executor_manager
.get_executor_state()
@@ -94,21 +102,23 @@ pub(crate) async fn get_executors<T: AsLogicalPlan, U:
AsExecutionPlan>(
})
.collect();
- Ok(warp::reply::json(&executors))
+ Json(executors)
}
-/// Return list of jobs
-pub(crate) async fn get_jobs<T: AsLogicalPlan, U: AsExecutionPlan>(
- data_server: SchedulerServer<T, U>,
-) -> Result<impl warp::Reply, Rejection> {
+pub async fn get_jobs<
+ T: AsLogicalPlan + Clone + Send + Sync + 'static,
+ U: AsExecutionPlan + Send + Sync + 'static,
+>(
+ State(data_server): State<Arc<SchedulerServer<T, U>>>,
+) -> Result<impl IntoResponse, StatusCode> {
// TODO: Display last seen information in UI
- let state = data_server.state;
+ let state = &data_server.state;
let jobs = state
.task_manager
.get_jobs()
.await
- .map_err(|_| warp::reject())?;
+ .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let jobs: Vec<JobResponse> = jobs
.iter()
@@ -157,31 +167,34 @@ pub(crate) async fn get_jobs<T: AsLogicalPlan, U:
AsExecutionPlan>(
})
.collect();
- Ok(warp::reply::json(&jobs))
+ Ok(Json(jobs))
}
-pub(crate) async fn cancel_job<T: AsLogicalPlan, U: AsExecutionPlan>(
- data_server: SchedulerServer<T, U>,
- job_id: String,
-) -> Result<impl warp::Reply, Rejection> {
+pub async fn cancel_job<
+ T: AsLogicalPlan + Clone + Send + Sync + 'static,
+ U: AsExecutionPlan + Send + Sync + 'static,
+>(
+ State(data_server): State<Arc<SchedulerServer<T, U>>>,
+ Path(job_id): Path<String>,
+) -> Result<impl IntoResponse, StatusCode> {
// 404 if job doesn't exist
data_server
.state
.task_manager
.get_job_status(&job_id)
.await
- .map_err(|_| warp::reject())?
- .ok_or_else(warp::reject)?;
+ .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
+ .ok_or(StatusCode::NOT_FOUND)?;
data_server
.query_stage_event_loop
.get_sender()
- .map_err(|_| warp::reject())?
+ .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.post_event(QueryStageSchedulerEvent::JobCancel(job_id))
.await
- .map_err(|_| warp::reject())?;
+ .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
- Ok(warp::reply::json(&CancelJobResponse { cancelled: true }))
+ Ok(Json(CancelJobResponse { cancelled: true }))
}
#[derive(Debug, serde::Serialize)]
@@ -189,69 +202,71 @@ pub struct QueryStagesResponse {
pub stages: Vec<QueryStageSummary>,
}
-/// Get the execution graph for the specified job id
-pub(crate) async fn get_query_stages<T: AsLogicalPlan, U: AsExecutionPlan>(
- data_server: SchedulerServer<T, U>,
- job_id: String,
-) -> Result<impl warp::Reply, Rejection> {
+pub async fn get_query_stages<
+ T: AsLogicalPlan + Clone + Send + Sync + 'static,
+ U: AsExecutionPlan + Send + Sync + 'static,
+>(
+ State(data_server): State<Arc<SchedulerServer<T, U>>>,
+ Path(job_id): Path<String>,
+) -> Result<impl IntoResponse, StatusCode> {
if let Some(graph) = data_server
.state
.task_manager
.get_job_execution_graph(&job_id)
.await
- .map_err(|_| warp::reject())?
+ .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
{
- Ok(warp::reply::json(&QueryStagesResponse {
- stages: graph
- .as_ref()
- .stages()
- .iter()
- .map(|(id, stage)| {
- let mut summary = QueryStageSummary {
- stage_id: id.to_string(),
- stage_status: stage.variant_name().to_string(),
- input_rows: 0,
- output_rows: 0,
- elapsed_compute: "".to_string(),
- };
- match stage {
- ExecutionStage::Running(running_stage) => {
- summary.input_rows = running_stage
- .stage_metrics
- .as_ref()
- .map(|m| get_combined_count(m.as_slice(),
"input_rows"))
- .unwrap_or(0);
- summary.output_rows = running_stage
- .stage_metrics
- .as_ref()
- .map(|m| get_combined_count(m.as_slice(),
"output_rows"))
- .unwrap_or(0);
- summary.elapsed_compute = running_stage
- .stage_metrics
- .as_ref()
- .map(|m|
get_elapsed_compute_nanos(m.as_slice()))
- .unwrap_or_default();
- }
- ExecutionStage::Successful(completed_stage) => {
- summary.input_rows = get_combined_count(
- &completed_stage.stage_metrics,
- "input_rows",
- );
- summary.output_rows = get_combined_count(
- &completed_stage.stage_metrics,
- "output_rows",
- );
- summary.elapsed_compute =
-
get_elapsed_compute_nanos(&completed_stage.stage_metrics);
- }
- _ => {}
+ let stages = graph
+ .as_ref()
+ .stages()
+ .iter()
+ .map(|(id, stage)| {
+ let mut summary = QueryStageSummary {
+ stage_id: id.to_string(),
+ stage_status: stage.variant_name().to_string(),
+ input_rows: 0,
+ output_rows: 0,
+ elapsed_compute: "".to_string(),
+ };
+ match stage {
+ ExecutionStage::Running(running_stage) => {
+ summary.input_rows = running_stage
+ .stage_metrics
+ .as_ref()
+ .map(|m| get_combined_count(m.as_slice(),
"input_rows"))
+ .unwrap_or(0);
+ summary.output_rows = running_stage
+ .stage_metrics
+ .as_ref()
+ .map(|m| get_combined_count(m.as_slice(),
"output_rows"))
+ .unwrap_or(0);
+ summary.elapsed_compute = running_stage
+ .stage_metrics
+ .as_ref()
+ .map(|m| get_elapsed_compute_nanos(m.as_slice()))
+ .unwrap_or_default();
+ }
+ ExecutionStage::Successful(completed_stage) => {
+ summary.input_rows = get_combined_count(
+ &completed_stage.stage_metrics,
+ "input_rows",
+ );
+ summary.output_rows = get_combined_count(
+ &completed_stage.stage_metrics,
+ "output_rows",
+ );
+ summary.elapsed_compute =
+
get_elapsed_compute_nanos(&completed_stage.stage_metrics);
}
- summary
- })
- .collect(),
- }))
+ _ => {}
+ }
+ summary
+ })
+ .collect();
+
+ Ok(Json(QueryStagesResponse { stages }))
} else {
- Ok(warp::reply::json(&QueryStagesResponse { stages: vec![] }))
+ Ok(Json(QueryStagesResponse { stages: vec![] }))
}
}
@@ -286,78 +301,96 @@ fn get_combined_count(metrics: &[MetricsSet], name: &str)
-> usize {
.sum()
}
-/// Generate a dot graph for the specified job id and return as plain text
-pub(crate) async fn get_job_dot_graph<T: AsLogicalPlan, U: AsExecutionPlan>(
- data_server: SchedulerServer<T, U>,
- job_id: String,
-) -> Result<String, Rejection> {
+pub async fn get_job_dot_graph<
+ T: AsLogicalPlan + Clone + Send + Sync + 'static,
+ U: AsExecutionPlan + Send + Sync + 'static,
+>(
+ State(data_server): State<Arc<SchedulerServer<T, U>>>,
+ Path(job_id): Path<String>,
+) -> Result<String, StatusCode> {
if let Some(graph) = data_server
.state
.task_manager
.get_job_execution_graph(&job_id)
.await
- .map_err(|_| warp::reject())?
+ .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
{
- ExecutionGraphDot::generate(graph.as_ref()).map_err(|_| warp::reject())
+ ExecutionGraphDot::generate(graph.as_ref())
+ .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} else {
Ok("Not Found".to_string())
}
}
-/// Generate a dot graph for the specified job id and query stage and return
as plain text
-pub(crate) async fn get_query_stage_dot_graph<T: AsLogicalPlan, U:
AsExecutionPlan>(
- data_server: SchedulerServer<T, U>,
- job_id: String,
- stage_id: usize,
-) -> Result<String, Rejection> {
+pub async fn get_query_stage_dot_graph<
+ T: AsLogicalPlan + Clone + Send + Sync + 'static,
+ U: AsExecutionPlan + Send + Sync + 'static,
+>(
+ State(data_server): State<Arc<SchedulerServer<T, U>>>,
+ Path((job_id, stage_id)): Path<(String, usize)>,
+) -> Result<impl IntoResponse, StatusCode> {
if let Some(graph) = data_server
.state
.task_manager
.get_job_execution_graph(&job_id)
.await
- .map_err(|_| warp::reject())?
+ .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
{
ExecutionGraphDot::generate_for_query_stage(graph.as_ref(), stage_id)
- .map_err(|_| warp::reject())
+ .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} else {
Ok("Not Found".to_string())
}
}
-/// Generate an SVG graph for the specified job id and return it as plain text
-pub(crate) async fn get_job_svg_graph<T: AsLogicalPlan, U: AsExecutionPlan>(
- data_server: SchedulerServer<T, U>,
- job_id: String,
-) -> Result<String, Rejection> {
- let dot = get_job_dot_graph(data_server, job_id).await;
- match dot {
- Ok(dot) => {
- let graph = graphviz_rust::parse(&dot);
- if let Ok(graph) = graph {
- exec(
- graph,
- &mut PrinterContext::default(),
- vec![CommandArg::Format(Format::Svg)],
- )
- .map(|bytes| String::from_utf8_lossy(&bytes).to_string())
- .map_err(|_| warp::reject())
- } else {
- Ok("Cannot parse graph".to_string())
- }
+pub async fn get_job_svg_graph<
+ T: AsLogicalPlan + Clone + Send + Sync + 'static,
+ U: AsExecutionPlan + Send + Sync + 'static,
+>(
+ State(data_server): State<Arc<SchedulerServer<T, U>>>,
+ Path(job_id): Path<String>,
+) -> Result<impl IntoResponse, StatusCode> {
+ let dot = get_job_dot_graph(State(data_server.clone()),
Path(job_id)).await?;
+ match graphviz_rust::parse(&dot) {
+ Ok(graph) => {
+ let result = exec(
+ graph,
+ &mut PrinterContext::default(),
+ vec![CommandArg::Format(Format::Svg)],
+ )
+ .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+
+ let svg = String::from_utf8_lossy(&result).to_string();
+ Ok(Response::builder()
+ .header(CONTENT_TYPE, "image/svg+xml")
+ .body(svg)
+ .unwrap())
}
- _ => Ok("Not Found".to_string()),
+ Err(_) => Ok(Response::builder()
+ .status(StatusCode::BAD_REQUEST)
+ .body("Cannot parse graph".to_string())
+ .unwrap()),
}
}
-pub(crate) async fn get_scheduler_metrics<T: AsLogicalPlan, U:
AsExecutionPlan>(
- data_server: SchedulerServer<T, U>,
-) -> Result<impl warp::Reply, Rejection> {
- Ok(data_server
- .metrics_collector()
- .gather_metrics()
- .map_err(|_| warp::reject())?
- .map(|(data, content_type)| {
- warp::reply::with_header(data, CONTENT_TYPE, content_type)
- })
- .unwrap_or_else(|| warp::reply::with_header(vec![], CONTENT_TYPE,
"text/html")))
+pub async fn get_scheduler_metrics<
+ T: AsLogicalPlan + Clone + Send + Sync + 'static,
+ U: AsExecutionPlan + Send + Sync + 'static,
+>(
+ State(data_server): State<Arc<SchedulerServer<T, U>>>,
+) -> impl IntoResponse {
+ match data_server.metrics_collector().gather_metrics() {
+ Ok(Some((data, content_type))) => Response::builder()
+ .header(CONTENT_TYPE, content_type)
+ .body(axum::body::Body::from(data))
+ .unwrap(),
+ Ok(None) => Response::builder()
+ .status(StatusCode::NO_CONTENT)
+ .body(axum::body::Body::empty())
+ .unwrap(),
+ Err(_) => Response::builder()
+ .status(StatusCode::INTERNAL_SERVER_ERROR)
+ .body(axum::body::Body::empty())
+ .unwrap(),
+ }
}
diff --git a/ballista/scheduler/src/api/mod.rs
b/ballista/scheduler/src/api/mod.rs
index 8f5555d0..c33d5157 100644
--- a/ballista/scheduler/src/api/mod.rs
+++ b/ballista/scheduler/src/api/mod.rs
@@ -13,126 +13,39 @@
mod handlers;
use crate::scheduler_server::SchedulerServer;
-use anyhow::Result;
+use axum::routing::patch;
+use axum::{routing::get, Router};
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
-use std::{
- pin::Pin,
- task::{Context as TaskContext, Poll},
-};
-use warp::filters::BoxedFilter;
-use warp::{Buf, Filter, Reply};
-
-pub enum EitherBody<A, B> {
- Left(A),
- Right(B),
-}
-
-pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
-pub type HttpBody = dyn http_body::Body<Data = dyn Buf, Error = Error> +
'static;
-
-impl<A, B> http_body::Body for EitherBody<A, B>
-where
- A: http_body::Body + Send + Unpin,
- B: http_body::Body<Data = A::Data> + Send + Unpin,
- A::Error: Into<Error>,
- B::Error: Into<Error>,
-{
- type Data = A::Data;
- type Error = Error;
-
- fn poll_data(
- self: Pin<&mut Self>,
- cx: &mut TaskContext<'_>,
- ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
- match self.get_mut() {
- EitherBody::Left(b) =>
Pin::new(b).poll_data(cx).map(map_option_err),
- EitherBody::Right(b) =>
Pin::new(b).poll_data(cx).map(map_option_err),
- }
- }
-
- fn poll_trailers(
- self: Pin<&mut Self>,
- cx: &mut TaskContext<'_>,
- ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
- match self.get_mut() {
- EitherBody::Left(b) =>
Pin::new(b).poll_trailers(cx).map_err(Into::into),
- EitherBody::Right(b) =>
Pin::new(b).poll_trailers(cx).map_err(Into::into),
- }
- }
-
- fn is_end_stream(&self) -> bool {
- match self {
- EitherBody::Left(b) => b.is_end_stream(),
- EitherBody::Right(b) => b.is_end_stream(),
- }
- }
-}
-
-fn map_option_err<T, U: Into<Error>>(
- err: Option<Result<T, U>>,
-) -> Option<Result<T, Error>> {
- err.map(|e| e.map_err(Into::into))
-}
-
-fn with_data_server<T: AsLogicalPlan + Clone, U: 'static + AsExecutionPlan>(
- db: SchedulerServer<T, U>,
-) -> impl Filter<Extract = (SchedulerServer<T, U>,), Error =
std::convert::Infallible> + Clone
-{
- warp::any().map(move || db.clone())
-}
-
-pub fn get_routes<T: AsLogicalPlan + Clone, U: 'static + AsExecutionPlan>(
- scheduler_server: SchedulerServer<T, U>,
-) -> BoxedFilter<(impl Reply,)> {
- let route_scheduler_state = warp::path!("api" / "state")
- .and(with_data_server(scheduler_server.clone()))
- .and_then(handlers::get_scheduler_state);
-
- let route_executors = warp::path!("api" / "executors")
- .and(with_data_server(scheduler_server.clone()))
- .and_then(handlers::get_executors);
-
- let route_jobs = warp::path!("api" / "jobs")
- .and(with_data_server(scheduler_server.clone()))
- .and_then(|data_server| handlers::get_jobs(data_server));
-
- let route_cancel_job = warp::path!("api" / "job" / String)
- .and(warp::patch())
- .and(with_data_server(scheduler_server.clone()))
- .and_then(|job_id, data_server| handlers::cancel_job(data_server,
job_id));
-
- let route_query_stages = warp::path!("api" / "job" / String / "stages")
- .and(with_data_server(scheduler_server.clone()))
- .and_then(|job_id, data_server|
handlers::get_query_stages(data_server, job_id));
-
- let route_job_dot = warp::path!("api" / "job" / String / "dot")
- .and(with_data_server(scheduler_server.clone()))
- .and_then(|job_id, data_server|
handlers::get_job_dot_graph(data_server, job_id));
-
- let route_query_stage_dot =
- warp::path!("api" / "job" / String / "stage" / usize / "dot")
- .and(with_data_server(scheduler_server.clone()))
- .and_then(|job_id, stage_id, data_server| {
- handlers::get_query_stage_dot_graph(data_server, job_id,
stage_id)
- });
-
- let route_job_dot_svg = warp::path!("api" / "job" / String / "dot_svg")
- .and(with_data_server(scheduler_server.clone()))
- .and_then(|job_id, data_server|
handlers::get_job_svg_graph(data_server, job_id));
-
- let route_scheduler_metrics = warp::path!("api" / "metrics")
- .and(with_data_server(scheduler_server))
- .and_then(|data_server| handlers::get_scheduler_metrics(data_server));
-
- let routes = route_scheduler_state
- .or(route_executors)
- .or(route_jobs)
- .or(route_cancel_job)
- .or(route_query_stages)
- .or(route_job_dot)
- .or(route_query_stage_dot)
- .or(route_job_dot_svg)
- .or(route_scheduler_metrics);
- routes.boxed()
+use std::sync::Arc;
+
+pub fn get_routes<
+ T: AsLogicalPlan + Clone + Send + Sync + 'static,
+ U: AsExecutionPlan + Send + Sync + 'static,
+>(
+ scheduler_server: Arc<SchedulerServer<T, U>>,
+) -> Router {
+ Router::new()
+ .route("/api/state", get(handlers::get_scheduler_state::<T, U>))
+ .route("/api/executors", get(handlers::get_executors::<T, U>))
+ .route("/api/jobs", get(handlers::get_jobs::<T, U>))
+ .route("/api/job/:job_id", patch(handlers::cancel_job::<T, U>))
+ .route(
+ "/api/job/:job_id/stages",
+ get(handlers::get_query_stages::<T, U>),
+ )
+ .route(
+ "/api/job/:job_id/dot",
+ get(handlers::get_job_dot_graph::<T, U>),
+ )
+ .route(
+ "/api/job/:job_id/stage/:stage_id/dot",
+ get(handlers::get_query_stage_dot_graph::<T, U>),
+ )
+ .route(
+ "/api/job/:job_id/dot_svg",
+ get(handlers::get_job_svg_graph::<T, U>),
+ )
+ .route("/api/metrics", get(handlers::get_scheduler_metrics::<T, U>))
+ .with_state(scheduler_server)
}
diff --git a/ballista/scheduler/src/bin/main.rs
b/ballista/scheduler/src/bin/main.rs
index ee9364c7..d2e2c9ce 100644
--- a/ballista/scheduler/src/bin/main.rs
+++ b/ballista/scheduler/src/bin/main.rs
@@ -47,8 +47,17 @@ mod config {
));
}
-#[tokio::main]
-async fn main() -> Result<()> {
+fn main() -> Result<()> {
+ let runtime = tokio::runtime::Builder::new_multi_thread()
+ .enable_io()
+ .enable_time()
+ .thread_stack_size(32 * 1024 * 1024) // 32MB
+ .build()
+ .unwrap();
+
+ runtime.block_on(inner())
+}
+async fn inner() -> Result<()> {
// parse options
let (opt, _remaining_args) =
Config::including_optional_config_files(&["/etc/ballista/scheduler.toml"])
diff --git a/ballista/scheduler/src/cluster/mod.rs
b/ballista/scheduler/src/cluster/mod.rs
index 8313f033..b7489a25 100644
--- a/ballista/scheduler/src/cluster/mod.rs
+++ b/ballista/scheduler/src/cluster/mod.rs
@@ -20,7 +20,7 @@ use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
-use clap::ArgEnum;
+use clap::ValueEnum;
use datafusion::common::tree_node::TreeNode;
use datafusion::common::tree_node::TreeNodeRecursion;
use datafusion::datasource::listing::PartitionedFile;
@@ -65,7 +65,7 @@ pub mod test_util;
// an enum used to configure the backend
// needs to be visible to code generated by configure_me
-#[derive(Debug, Clone, ArgEnum, serde::Deserialize, PartialEq, Eq)]
+#[derive(Debug, Clone, ValueEnum, serde::Deserialize, PartialEq, Eq)]
pub enum ClusterStorage {
Etcd,
Memory,
@@ -76,7 +76,7 @@ impl std::str::FromStr for ClusterStorage {
type Err = String;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
- ArgEnum::from_str(s, true)
+ ValueEnum::from_str(s, true)
}
}
@@ -764,7 +764,10 @@ mod test {
};
use crate::state::execution_graph::ExecutionGraph;
use crate::state::task_manager::JobInfoCache;
- use crate::test_utils::{mock_completed_task,
test_aggregation_plan_with_job_id};
+ use crate::test_utils::{
+ mock_completed_task, revive_graph_and_complete_next_stage,
+ test_aggregation_plan_with_job_id,
+ };
#[tokio::test]
async fn test_bind_task_bias() -> Result<()> {
@@ -1008,10 +1011,11 @@ mod test {
async fn mock_graph(
job_id: &str,
- num_partition: usize,
+ num_target_partitions: usize,
num_pending_task: usize,
) -> Result<ExecutionGraph> {
- let mut graph = test_aggregation_plan_with_job_id(num_partition,
job_id).await;
+ let mut graph =
+ test_aggregation_plan_with_job_id(num_target_partitions,
job_id).await;
let executor = ExecutorMetadata {
id: "executor_0".to_string(),
host: "localhost".to_string(),
@@ -1020,14 +1024,10 @@ mod test {
specification: ExecutorSpecification { task_slots: 32 },
};
- if let Some(task) = graph.pop_next_task(&executor.id)? {
- let task_status = mock_completed_task(task, &executor.id);
- graph.update_task_status(&executor, vec![task_status], 1, 1)?;
- }
-
- graph.revive();
+ // complete first stage
+ revive_graph_and_complete_next_stage(&mut graph)?;
- for _i in 0..num_partition - num_pending_task {
+ for _ in 0..num_target_partitions - num_pending_task {
if let Some(task) = graph.pop_next_task(&executor.id)? {
let task_status = mock_completed_task(task, &executor.id);
graph.update_task_status(&executor, vec![task_status], 1, 1)?;
diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs
index d15e928c..82280911 100644
--- a/ballista/scheduler/src/config.rs
+++ b/ballista/scheduler/src/config.rs
@@ -19,7 +19,7 @@
//! Ballista scheduler specific configuration
use ballista_core::config::TaskSchedulingPolicy;
-use clap::ArgEnum;
+use clap::ValueEnum;
use std::fmt;
/// Configurations for the ballista scheduler of scheduling jobs and tasks
@@ -189,7 +189,7 @@ pub enum ClusterStorageConfig {
/// Policy of distributing tasks to available executor slots
///
/// It needs to be visible to code generated by configure_me
-#[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)]
+#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)]
pub enum TaskDistribution {
/// Eagerly assign tasks to executor slots. This will assign as many task
slots per executor
/// as are currently available
@@ -208,7 +208,7 @@ impl std::str::FromStr for TaskDistribution {
type Err = String;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
- ArgEnum::from_str(s, true)
+ ValueEnum::from_str(s, true)
}
}
diff --git a/ballista/scheduler/src/planner.rs
b/ballista/scheduler/src/planner.rs
index 3da9f339..0e18a062 100644
--- a/ballista/scheduler/src/planner.rs
+++ b/ballista/scheduler/src/planner.rs
@@ -592,6 +592,8 @@ order by
Ok(())
}
+ #[ignore]
+ // enable when upgrading Datafusion, a bug is fixed with
https://github.com/apache/datafusion/pull/11926/
#[tokio::test]
async fn roundtrip_serde_aggregate() -> Result<(), BallistaError> {
let ctx = datafusion_test_context("testdata").await?;
diff --git a/ballista/scheduler/src/scheduler_process.rs
b/ballista/scheduler/src/scheduler_process.rs
index 6bcaaec5..1f7f7ac3 100644
--- a/ballista/scheduler/src/scheduler_process.rs
+++ b/ballista/scheduler/src/scheduler_process.rs
@@ -15,26 +15,18 @@
// specific language governing permissions and limitations
// under the License.
-use anyhow::{Context, Result};
+use anyhow::{Error, Result};
#[cfg(feature = "flight-sql")]
use arrow_flight::flight_service_server::FlightServiceServer;
-use futures::future::{self, Either, TryFutureExt};
-use hyper::{server::conn::AddrStream, service::make_service_fn, Server};
-use log::info;
-use std::convert::Infallible;
-use std::net::SocketAddr;
-use std::sync::Arc;
-use tonic::transport::server::Connected;
-use tower::Service;
-
-use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
-
use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer;
use ballista_core::serde::BallistaCodec;
use ballista_core::utils::create_grpc_server;
use ballista_core::BALLISTA_VERSION;
+use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
+use log::info;
+use std::{net::SocketAddr, sync::Arc};
-use crate::api::{get_routes, EitherBody, Error};
+use crate::api::get_routes;
use crate::cluster::BallistaCluster;
use crate::config::SchedulerConfig;
use crate::flight_sql::FlightSqlServiceImpl;
@@ -70,58 +62,31 @@ pub async fn start_server(
scheduler_server.init().await?;
- Server::bind(&addr)
- .serve(make_service_fn(move |request: &AddrStream| {
- let config = &scheduler_server.state.config;
- let scheduler_grpc_server =
- SchedulerGrpcServer::new(scheduler_server.clone())
- .max_encoding_message_size(
- config.grpc_server_max_encoding_message_size as usize,
- )
- .max_decoding_message_size(
- config.grpc_server_max_decoding_message_size as usize,
- );
-
- let keda_scaler =
ExternalScalerServer::new(scheduler_server.clone());
-
- let tonic_builder = create_grpc_server()
- .add_service(scheduler_grpc_server)
- .add_service(keda_scaler);
+ let config = &scheduler_server.state.config;
+ let scheduler_grpc_server =
SchedulerGrpcServer::new(scheduler_server.clone())
+
.max_encoding_message_size(config.grpc_server_max_encoding_message_size as
usize)
+
.max_decoding_message_size(config.grpc_server_max_decoding_message_size as
usize);
- #[cfg(feature = "flight-sql")]
- let tonic_builder =
tonic_builder.add_service(FlightServiceServer::new(
- FlightSqlServiceImpl::new(scheduler_server.clone()),
- ));
+ let keda_scaler = ExternalScalerServer::new(scheduler_server.clone());
- let mut tonic = tonic_builder.into_service();
+ let tonic_builder = create_grpc_server()
+ .add_service(scheduler_grpc_server)
+ .add_service(keda_scaler);
- let mut warp = warp::service(get_routes(scheduler_server.clone()));
+ #[cfg(feature = "flight-sql")]
+ let tonic_builder = tonic_builder.add_service(FlightServiceServer::new(
+ FlightSqlServiceImpl::new(scheduler_server.clone()),
+ ));
- let connect_info = request.connect_info();
- future::ok::<_, Infallible>(tower::service_fn(
- move |req: hyper::Request<hyper::Body>| {
- // Set the connect info from hyper to tonic
- let (mut parts, body) = req.into_parts();
- parts.extensions.insert(connect_info.clone());
- let req = http::Request::from_parts(parts, body);
+ let tonic = tonic_builder.into_service().into_router();
- if req.uri().path().starts_with("/api") {
- return Either::Left(
- warp.call(req)
- .map_ok(|res| res.map(EitherBody::Left))
- .map_err(Error::from),
- );
- }
+ let axum = get_routes(Arc::new(scheduler_server));
+ let merged = axum
+ .merge(tonic)
+ .into_make_service_with_connect_info::<SocketAddr>();
- Either::Right(
- tonic
- .call(req)
- .map_ok(|res| res.map(EitherBody::Right))
- .map_err(Error::from),
- )
- },
- ))
- }))
+ axum::Server::bind(&addr)
+ .serve(merged)
.await
- .context("Could not start grpc server")
+ .map_err(Error::from)
}
diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs
b/ballista/scheduler/src/scheduler_server/grpc.rs
index 2d759fb7..6992bf75 100644
--- a/ballista/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/scheduler/src/scheduler_server/grpc.rs
@@ -15,10 +15,12 @@
// specific language governing permissions and limitations
// under the License.
+use axum::extract::ConnectInfo;
use ballista_core::config::{BallistaConfig, BALLISTA_JOB_NAME};
use ballista_core::serde::protobuf::execute_query_params::{OptionalSessionId,
Query};
use std::collections::HashMap;
use std::convert::TryInto;
+use std::net::SocketAddr;
use ballista_core::serde::protobuf::executor_registration::OptionalHost;
use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc;
@@ -70,7 +72,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerGrpc
"Bad request because poll work is not supported for push-based
task scheduling",
));
}
- let remote_addr = request.remote_addr();
+ let remote_addr = extract_connect_info(&request);
if let PollWorkParams {
metadata: Some(metadata),
num_free_slots,
@@ -155,7 +157,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerGrpc
&self,
request: Request<RegisterExecutorParams>,
) -> Result<Response<RegisterExecutorResult>, Status> {
- let remote_addr = request.remote_addr();
+ let remote_addr = extract_connect_info(&request);
if let RegisterExecutorParams {
metadata: Some(metadata),
} = request.into_inner()
@@ -191,7 +193,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerGrpc
&self,
request: Request<HeartBeatParams>,
) -> Result<Response<HeartBeatResult>, Status> {
- let remote_addr = request.remote_addr();
+ let remote_addr = extract_connect_info(&request);
let HeartBeatParams {
executor_id,
metrics,
@@ -634,6 +636,13 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerGrpc
}
}
+fn extract_connect_info<T>(request: &Request<T>) ->
Option<ConnectInfo<SocketAddr>> {
+ request
+ .extensions()
+ .get::<ConnectInfo<SocketAddr>>()
+ .cloned()
+}
+
#[cfg(all(test, feature = "sled"))]
mod test {
use std::sync::Arc;
diff --git a/ballista/scheduler/src/scheduler_server/mod.rs
b/ballista/scheduler/src/scheduler_server/mod.rs
index e6525f18..c2bf657b 100644
--- a/ballista/scheduler/src/scheduler_server/mod.rs
+++ b/ballista/scheduler/src/scheduler_server/mod.rs
@@ -345,7 +345,7 @@ mod test {
use datafusion::functions_aggregate::sum::sum;
use datafusion::logical_expr::{col, LogicalPlan};
- use datafusion::test_util::scan_empty;
+ use datafusion::test_util::scan_empty_with_partitions;
use datafusion_proto::protobuf::LogicalPlanNode;
use datafusion_proto::protobuf::PhysicalPlanNode;
@@ -700,7 +700,9 @@ mod test {
Field::new("gmv", DataType::UInt64, false),
]);
- scan_empty(None, &schema, Some(vec![0, 1]))
+ // partitions need to be > 1 for the datafusion's optimizer to insert
a repartition node
+ // behavior changed with:
https://github.com/apache/datafusion/pull/11875
+ scan_empty_with_partitions(None, &schema, Some(vec![0, 1]), 2)
.unwrap()
.aggregate(vec![col("id")], vec![sum(col("gmv"))])
.unwrap()
diff --git a/ballista/scheduler/src/state/execution_graph.rs
b/ballista/scheduler/src/state/execution_graph.rs
index 9ee95d67..333545d3 100644
--- a/ballista/scheduler/src/state/execution_graph.rs
+++ b/ballista/scheduler/src/state/execution_graph.rs
@@ -1694,7 +1694,9 @@ mod test {
use crate::state::execution_graph::ExecutionGraph;
use crate::test_utils::{
- mock_completed_task, mock_executor, mock_failed_task,
test_aggregation_plan,
+ mock_completed_task, mock_executor, mock_failed_task,
+ revive_graph_and_complete_next_stage,
+ revive_graph_and_complete_next_stage_with_executor,
test_aggregation_plan,
test_coalesce_plan, test_join_plan, test_two_aggregations_plan,
test_union_all_plan, test_union_plan,
};
@@ -1793,19 +1795,13 @@ mod test {
join_graph.revive();
assert_eq!(join_graph.stage_count(), 4);
- assert_eq!(join_graph.available_tasks(), 2);
+ assert_eq!(join_graph.available_tasks(), 4);
// Complete the first stage
- if let Some(task) = join_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- join_graph.update_task_status(&executor1, vec![task_status], 1,
1)?;
- }
+ revive_graph_and_complete_next_stage_with_executor(&mut join_graph,
&executor1)?;
// Complete the second stage
- if let Some(task) = join_graph.pop_next_task(&executor2.id)? {
- let task_status = mock_completed_task(task, &executor2.id);
- join_graph.update_task_status(&executor2, vec![task_status], 1,
1)?;
- }
+ revive_graph_and_complete_next_stage_with_executor(&mut join_graph,
&executor2)?;
join_graph.revive();
// There are 4 tasks pending schedule for the 3rd stage
@@ -1823,7 +1819,7 @@ mod test {
// Two stages were reset, 1 Running stage rollback to Unresolved and 1
Completed stage move to Running
assert_eq!(reset.0.len(), 2);
- assert_eq!(join_graph.available_tasks(), 1);
+ assert_eq!(join_graph.available_tasks(), 2);
drain_tasks(&mut join_graph)?;
assert!(join_graph.is_successful(), "Failed to complete join plan");
@@ -1844,19 +1840,19 @@ mod test {
join_graph.revive();
assert_eq!(join_graph.stage_count(), 4);
- assert_eq!(join_graph.available_tasks(), 2);
+ assert_eq!(join_graph.available_tasks(), 4);
// Complete the first stage
- if let Some(task) = join_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- join_graph.update_task_status(&executor1, vec![task_status], 1,
1)?;
- }
+ assert_eq!(revive_graph_and_complete_next_stage(&mut join_graph)?, 2);
// Complete the second stage
- if let Some(task) = join_graph.pop_next_task(&executor2.id)? {
- let task_status = mock_completed_task(task, &executor2.id);
- join_graph.update_task_status(&executor2, vec![task_status], 1,
1)?;
- }
+ assert_eq!(
+ revive_graph_and_complete_next_stage_with_executor(
+ &mut join_graph,
+ &executor2
+ )?,
+ 2
+ );
// There are 0 tasks pending schedule now
assert_eq!(join_graph.available_tasks(), 0);
@@ -1865,7 +1861,7 @@ mod test {
// Two stages were reset, 1 Resolved stage rollback to Unresolved and
1 Completed stage move to Running
assert_eq!(reset.0.len(), 2);
- assert_eq!(join_graph.available_tasks(), 1);
+ assert_eq!(join_graph.available_tasks(), 2);
drain_tasks(&mut join_graph)?;
assert!(join_graph.is_successful(), "Failed to complete join plan");
@@ -1886,13 +1882,10 @@ mod test {
agg_graph.revive();
assert_eq!(agg_graph.stage_count(), 2);
- assert_eq!(agg_graph.available_tasks(), 1);
+ assert_eq!(agg_graph.available_tasks(), 2);
// Complete the first stage
- if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
- }
+ revive_graph_and_complete_next_stage_with_executor(&mut agg_graph,
&executor1)?;
// 1st task in the second stage
if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
@@ -1920,12 +1913,12 @@ mod test {
// Two stages were reset, 1 Running stage rollback to Unresolved and 1
Completed stage move to Running
assert_eq!(reset.0.len(), 2);
- assert_eq!(agg_graph.available_tasks(), 1);
+ assert_eq!(agg_graph.available_tasks(), 2);
// Call the reset again
let reset = agg_graph.reset_stages_on_lost_executor(&executor1.id)?;
assert_eq!(reset.0.len(), 0);
- assert_eq!(agg_graph.available_tasks(), 1);
+ assert_eq!(agg_graph.available_tasks(), 2);
drain_tasks(&mut agg_graph)?;
assert!(agg_graph.is_successful(), "Failed to complete agg plan");
@@ -1935,24 +1928,20 @@ mod test {
#[tokio::test]
async fn test_do_not_retry_killed_task() -> Result<()> {
- let executor1 = mock_executor("executor-id1".to_string());
- let executor2 = mock_executor("executor-id2".to_string());
+ let executor = mock_executor("executor-id-123".to_string());
let mut agg_graph = test_aggregation_plan(4).await;
// Call revive to move the leaf Resolved stages to Running
agg_graph.revive();
// Complete the first stage
- if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
- }
+ revive_graph_and_complete_next_stage(&mut agg_graph)?;
// 1st task in the second stage
- let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
- let task_status1 = mock_completed_task(task1, &executor2.id);
+ let task1 = agg_graph.pop_next_task(&executor.id)?.unwrap();
+ let task_status1 = mock_completed_task(task1, &executor.id);
// 2rd task in the second stage
- let task2 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
+ let task2 = agg_graph.pop_next_task(&executor.id)?.unwrap();
let task_status2 = mock_failed_task(
task2,
FailedTask {
@@ -1964,7 +1953,7 @@ mod test {
);
agg_graph.update_task_status(
- &executor2,
+ &executor,
vec![task_status1, task_status2],
4,
4,
@@ -1983,24 +1972,20 @@ mod test {
#[tokio::test]
async fn test_max_task_failed_count() -> Result<()> {
- let executor1 = mock_executor("executor-id1".to_string());
- let executor2 = mock_executor("executor-id2".to_string());
+ let executor = mock_executor("executor-id2".to_string());
let mut agg_graph = test_aggregation_plan(2).await;
// Call revive to move the leaf Resolved stages to Running
agg_graph.revive();
// Complete the first stage
- if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
- }
+ revive_graph_and_complete_next_stage(&mut agg_graph)?;
// 1st task in the second stage
- let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
- let task_status1 = mock_completed_task(task1, &executor2.id);
+ let task1 = agg_graph.pop_next_task(&executor.id)?.unwrap();
+ let task_status1 = mock_completed_task(task1, &executor.id);
// 2nd task in the second stage, failed due to IOError
- let task2 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
+ let task2 = agg_graph.pop_next_task(&executor.id)?.unwrap();
let task_status2 = mock_failed_task(
task2.clone(),
FailedTask {
@@ -2012,7 +1997,7 @@ mod test {
);
agg_graph.update_task_status(
- &executor2,
+ &executor,
vec![task_status1, task_status2],
4,
4,
@@ -2023,7 +2008,7 @@ mod test {
let mut last_attempt = 0;
// 2nd task's attempts
for attempt in 1..5 {
- if let Some(task2_attempt) =
agg_graph.pop_next_task(&executor2.id)? {
+ if let Some(task2_attempt) =
agg_graph.pop_next_task(&executor.id)? {
assert_eq!(
task2_attempt.partition.partition_id,
task2.partition.partition_id
@@ -2041,7 +2026,7 @@ mod test {
)),
},
);
- agg_graph.update_task_status(&executor2, vec![task_status], 4,
4)?;
+ agg_graph.update_task_status(&executor, vec![task_status], 4,
4)?;
}
}
@@ -2075,10 +2060,7 @@ mod test {
agg_graph.revive();
// Complete the Stage 1
- if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
- }
+ revive_graph_and_complete_next_stage_with_executor(&mut agg_graph,
&executor1)?;
// 1st task in the Stage 2
if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
@@ -2103,13 +2085,10 @@ mod test {
// Two stages were reset, Stage 2 rollback to Unresolved and Stage 1
move to Running
assert_eq!(reset.0.len(), 2);
- assert_eq!(agg_graph.available_tasks(), 1);
+ assert_eq!(agg_graph.available_tasks(), 2);
// Complete the Stage 1 again
- if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
- }
+ revive_graph_and_complete_next_stage_with_executor(&mut agg_graph,
&executor1)?;
// Stage 2 move to Running
agg_graph.revive();
@@ -2148,10 +2127,7 @@ mod test {
agg_graph.revive();
// Complete the Stage 1
- if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
- }
+ revive_graph_and_complete_next_stage(&mut agg_graph)?;
// 1st task in the Stage 2
let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
@@ -2198,7 +2174,7 @@ mod test {
let running_stage = agg_graph.running_stages();
assert_eq!(running_stage.len(), 1);
assert_eq!(running_stage[0], 1);
- assert_eq!(agg_graph.available_tasks(), 1);
+ assert_eq!(agg_graph.available_tasks(), 2);
drain_tasks(&mut agg_graph)?;
assert!(agg_graph.is_successful(), "Failed to complete agg plan");
@@ -2216,10 +2192,7 @@ mod test {
assert_eq!(agg_graph.stage_count(), 3);
// Complete the Stage 1
- if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
- }
+ revive_graph_and_complete_next_stage(&mut agg_graph)?;
// Complete the Stage 2, 5 tasks run on executor_2 and 3 tasks run on
executor_1
for _i in 0..5 {
@@ -2283,10 +2256,7 @@ mod test {
agg_graph.revive();
for attempt in 0..6 {
- if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- agg_graph.update_task_status(&executor1, vec![task_status], 4,
4)?;
- }
+ revive_graph_and_complete_next_stage(&mut agg_graph)?;
// 1st task in the Stage 2, failed due to FetchPartitionError
if let Some(task1) = agg_graph.pop_next_task(&executor2.id)? {
@@ -2318,7 +2288,7 @@ mod test {
let running_stage = agg_graph.running_stages();
assert_eq!(running_stage.len(), 1);
assert_eq!(running_stage[0], 1);
- assert_eq!(agg_graph.available_tasks(), 1);
+ assert_eq!(agg_graph.available_tasks(), 2);
} else {
// Job is failed after exceeds the max_stage_failures
assert_eq!(stage_events.len(), 1);
@@ -2355,10 +2325,7 @@ mod test {
assert_eq!(agg_graph.stage_count(), 3);
// Complete the Stage 1
- if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
- }
+ revive_graph_and_complete_next_stage(&mut agg_graph)?;
// Complete the Stage 2, 5 tasks run on executor_2, 2 tasks run on
executor_1, 1 task runs on executor_3
for _i in 0..5 {
@@ -2559,10 +2526,7 @@ mod test {
assert_eq!(agg_graph.stage_count(), 3);
// Complete the Stage 1
- if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
- }
+ revive_graph_and_complete_next_stage(&mut agg_graph)?;
// Complete the Stage 2, 5 tasks run on executor_2, 3 tasks run on
executor_1
for _i in 0..5 {
@@ -2662,10 +2626,7 @@ mod test {
assert_eq!(agg_graph.stage_count(), 3);
// Complete the Stage 1
- if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
- }
+ revive_graph_and_complete_next_stage(&mut agg_graph)?;
// Complete the Stage 2, 5 tasks run on executor_2, 3 tasks run on
executor_1
for _i in 0..5 {
@@ -2735,7 +2696,7 @@ mod test {
let running_stage = agg_graph.running_stages();
assert_eq!(running_stage.len(), 1);
assert_eq!(running_stage[0], 1);
- assert_eq!(agg_graph.available_tasks(), 1);
+ assert_eq!(agg_graph.available_tasks(), 2);
// There are two failed stage attempts: Stage 2 and Stage 3
assert_eq!(agg_graph.failed_stage_attempts.len(), 2);
@@ -2759,14 +2720,9 @@ mod test {
let executor1 = mock_executor("executor-id1".to_string());
let executor2 = mock_executor("executor-id2".to_string());
let mut agg_graph = test_aggregation_plan(4).await;
- // Call revive to move the leaf Resolved stages to Running
- agg_graph.revive();
// Complete the Stage 1
- if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
- let task_status = mock_completed_task(task, &executor1.id);
- agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
- }
+ revive_graph_and_complete_next_stage(&mut agg_graph)?;
// 1st task in the Stage 2
let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
diff --git a/ballista/scheduler/src/state/execution_graph_dot.rs
b/ballista/scheduler/src/state/execution_graph_dot.rs
index 3a6dce7a..d5d7e7ae 100644
--- a/ballista/scheduler/src/state/execution_graph_dot.rs
+++ b/ballista/scheduler/src/state/execution_graph_dot.rs
@@ -432,13 +432,13 @@ mod tests {
let expected = r#"digraph G {
subgraph cluster0 {
label = "Stage 1 [Resolved]";
- stage_1_0 [shape=box, label="ShuffleWriter [0 partitions]"]
+ stage_1_0 [shape=box, label="ShuffleWriter [2 partitions]"]
stage_1_0_0 [shape=box, label="MemoryExec"]
stage_1_0_0 -> stage_1_0
}
subgraph cluster1 {
label = "Stage 2 [Resolved]";
- stage_2_0 [shape=box, label="ShuffleWriter [0 partitions]"]
+ stage_2_0 [shape=box, label="ShuffleWriter [2 partitions]"]
stage_2_0_0 [shape=box, label="MemoryExec"]
stage_2_0_0 -> stage_2_0
}
@@ -462,7 +462,7 @@ filter_expr="]
}
subgraph cluster3 {
label = "Stage 4 [Resolved]";
- stage_4_0 [shape=box, label="ShuffleWriter [0 partitions]"]
+ stage_4_0 [shape=box, label="ShuffleWriter [2 partitions]"]
stage_4_0_0 [shape=box, label="MemoryExec"]
stage_4_0_0 -> stage_4_0
}
@@ -531,19 +531,19 @@ filter_expr="]
let expected = r#"digraph G {
subgraph cluster0 {
label = "Stage 1 [Resolved]";
- stage_1_0 [shape=box, label="ShuffleWriter [0 partitions]"]
+ stage_1_0 [shape=box, label="ShuffleWriter [2 partitions]"]
stage_1_0_0 [shape=box, label="MemoryExec"]
stage_1_0_0 -> stage_1_0
}
subgraph cluster1 {
label = "Stage 2 [Resolved]";
- stage_2_0 [shape=box, label="ShuffleWriter [0 partitions]"]
+ stage_2_0 [shape=box, label="ShuffleWriter [2 partitions]"]
stage_2_0_0 [shape=box, label="MemoryExec"]
stage_2_0_0 -> stage_2_0
}
subgraph cluster2 {
label = "Stage 3 [Resolved]";
- stage_3_0 [shape=box, label="ShuffleWriter [0 partitions]"]
+ stage_3_0 [shape=box, label="ShuffleWriter [2 partitions]"]
stage_3_0_0 [shape=box, label="MemoryExec"]
stage_3_0_0 -> stage_3_0
}
@@ -635,7 +635,7 @@ filter_expr="]
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::UInt32, false),
]));
- let table = Arc::new(MemTable::try_new(schema.clone(), vec![])?);
+ let table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![],
vec![]])?);
ctx.register_table("foo", table.clone())?;
ctx.register_table("bar", table.clone())?;
ctx.register_table("baz", table)?;
@@ -660,7 +660,8 @@ filter_expr="]
let ctx = SessionContext::new_with_config(config);
let schema =
Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32,
false)]));
- let table = Arc::new(MemTable::try_new(schema.clone(), vec![])?);
+ // we specify the input partitions to be > 1 because of
https://github.com/apache/datafusion/issues/12611
+ let table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![],
vec![]])?);
ctx.register_table("foo", table.clone())?;
ctx.register_table("bar", table.clone())?;
ctx.register_table("baz", table)?;
diff --git a/ballista/scheduler/src/test_utils.rs
b/ballista/scheduler/src/test_utils.rs
index 59e6a875..5e5dee12 100644
--- a/ballista/scheduler/src/test_utils.rs
+++ b/ballista/scheduler/src/test_utils.rs
@@ -16,6 +16,7 @@
// under the License.
use ballista_core::error::{BallistaError, Result};
+use datafusion::catalog::Session;
use std::any::Any;
use std::collections::HashMap;
use std::future::Future;
@@ -44,19 +45,18 @@ use ballista_core::serde::{protobuf, BallistaCodec};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::common::DataFusionError;
use datafusion::datasource::{TableProvider, TableType};
-use datafusion::execution::context::{SessionConfig, SessionContext,
SessionState};
-use datafusion::functions_aggregate::sum::sum;
-use datafusion::logical_expr::expr::Sort;
-use datafusion::logical_expr::{Expr, LogicalPlan};
+use datafusion::execution::context::{SessionConfig, SessionContext};
+use datafusion::functions_aggregate::{count::count, sum::sum};
+use datafusion::logical_expr::{Expr, LogicalPlan, SortExpr};
use datafusion::physical_plan::display::DisplayableExecutionPlan;
use datafusion::physical_plan::ExecutionPlan;
-use datafusion::prelude::{col, count, CsvReadOptions, JoinType};
-use datafusion::test_util::scan_empty;
+use datafusion::prelude::{col, CsvReadOptions, JoinType};
+use datafusion::test_util::scan_empty_with_partitions;
use crate::cluster::BallistaCluster;
use crate::scheduler_server::event::QueryStageSchedulerEvent;
-use crate::state::execution_graph::{ExecutionGraph, TaskDescription};
+use crate::state::execution_graph::{ExecutionGraph, ExecutionStage,
TaskDescription};
use ballista_core::utils::default_session_builder;
use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
use parking_lot::Mutex;
@@ -89,7 +89,7 @@ impl TableProvider for ExplodingTableProvider {
async fn scan(
&self,
- _ctx: &SessionState,
+ _ctx: &dyn Session,
_projection: Option<&Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
@@ -783,6 +783,47 @@ pub fn assert_failed_event(job_id: &str, collector:
&TestMetricsCollector) {
assert!(found, "{}", "Expected failed event for job {job_id}");
}
+pub fn revive_graph_and_complete_next_stage(graph: &mut ExecutionGraph) ->
Result<usize> {
+ let executor = mock_executor("executor-id1".to_string());
+ revive_graph_and_complete_next_stage_with_executor(graph, &executor)
+}
+
+pub fn revive_graph_and_complete_next_stage_with_executor(
+ graph: &mut ExecutionGraph,
+ executor: &ExecutorMetadata,
+) -> Result<usize> {
+ graph.revive();
+
+ // find the num_available_tasks of the next running stage
+ let num_available_tasks = graph
+ .stages()
+ .iter()
+ .map(|(_stage_id, stage)| {
+ if let ExecutionStage::Running(stage) = stage {
+ stage
+ .task_infos
+ .iter()
+ .filter(|info| info.is_none())
+ .count()
+ } else {
+ 0
+ }
+ })
+ .find(|num_available_tasks| num_available_tasks > &0)
+ .unwrap();
+
+ if num_available_tasks > 0 {
+ for _ in 0..num_available_tasks {
+ if let Some(task) = graph.pop_next_task(&executor.id).unwrap() {
+ let task_status = mock_completed_task(task, &executor.id);
+ graph.update_task_status(executor, vec![task_status], 1, 1)?;
+ }
+ }
+ }
+
+ Ok(num_available_tasks)
+}
+
pub async fn test_aggregation_plan(partition: usize) -> ExecutionGraph {
test_aggregation_plan_with_job_id(partition, "job").await
}
@@ -800,7 +841,8 @@ pub async fn test_aggregation_plan_with_job_id(
Field::new("gmv", DataType::UInt64, false),
]);
- let logical_plan = scan_empty(None, &schema, Some(vec![0, 1]))
+ // we specify the input partitions to be > 1 because of
https://github.com/apache/datafusion/issues/12611
+ let logical_plan = scan_empty_with_partitions(None, &schema, Some(vec![0,
1]), 2)
.unwrap()
.aggregate(vec![col("id")], vec![sum(col("gmv"))])
.unwrap()
@@ -833,7 +875,8 @@ pub async fn test_two_aggregations_plan(partition: usize)
-> ExecutionGraph {
Field::new("gmv", DataType::UInt64, false),
]);
- let logical_plan = scan_empty(None, &schema, Some(vec![0, 1, 2]))
+ // we specify the input partitions to be > 1 because of
https://github.com/apache/datafusion/issues/12611
+ let logical_plan = scan_empty_with_partitions(None, &schema, Some(vec![0,
1, 2]), 2)
.unwrap()
.aggregate(vec![col("id"), col("name")], vec![sum(col("gmv"))])
.unwrap()
@@ -867,7 +910,8 @@ pub async fn test_coalesce_plan(partition: usize) ->
ExecutionGraph {
Field::new("gmv", DataType::UInt64, false),
]);
- let logical_plan = scan_empty(None, &schema, Some(vec![0, 1]))
+ // we specify the input partitions to be > 1 because of
https://github.com/apache/datafusion/issues/12611
+ let logical_plan = scan_empty_with_partitions(None, &schema, Some(vec![0,
1]), 2)
.unwrap()
.limit(0, Some(1))
.unwrap()
@@ -898,14 +942,15 @@ pub async fn test_join_plan(partition: usize) ->
ExecutionGraph {
Field::new("gmv", DataType::UInt64, false),
]);
- let left_plan = scan_empty(Some("left"), &schema, None).unwrap();
+ // we specify the input partitions to be > 1 because of
https://github.com/apache/datafusion/issues/12611
+ let left_plan = scan_empty_with_partitions(Some("left"), &schema, None,
2).unwrap();
- let right_plan = scan_empty(Some("right"), &schema, None)
+ let right_plan = scan_empty_with_partitions(Some("right"), &schema, None,
2)
.unwrap()
.build()
.unwrap();
- let sort_expr = Expr::Sort(Sort::new(Box::new(col("id")), false, false));
+ let sort_expr = Expr::Sort(SortExpr::new(Box::new(col("id")), false,
false));
let logical_plan = left_plan
.join(right_plan, JoinType::Inner, (vec!["id"], vec!["id"]), None)
diff --git a/examples/Cargo.toml b/examples/Cargo.toml
index e41e6051..3fd07740 100644
--- a/examples/Cargo.toml
+++ b/examples/Cargo.toml
@@ -38,7 +38,7 @@ ballista = { path = "../ballista/client", version = "0.12.0" }
datafusion = { workspace = true }
futures = "0.3"
num_cpus = "1.13.0"
-prost = "0.12"
+prost = { workspace = true }
tokio = { version = "1.0", features = [
"macros",
"rt",
@@ -46,4 +46,4 @@ tokio = { version = "1.0", features = [
"sync",
"parking_lot"
] }
-tonic = "0.10"
+tonic = { workspace = true }
diff --git a/python/.cargo/config.toml b/python/.cargo/config.toml
new file mode 100644
index 00000000..d47f983e
--- /dev/null
+++ b/python/.cargo/config.toml
@@ -0,0 +1,11 @@
+[target.x86_64-apple-darwin]
+rustflags = [
+ "-C", "link-arg=-undefined",
+ "-C", "link-arg=dynamic_lookup",
+]
+
+[target.aarch64-apple-darwin]
+rustflags = [
+ "-C", "link-arg=-undefined",
+ "-C", "link-arg=dynamic_lookup",
+]
diff --git a/python/Cargo.toml b/python/Cargo.toml
index 2b2dff41..eb662cb1 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -33,13 +33,12 @@ publish = false
async-trait = "0.1.77"
ballista = { path = "../ballista/client", version = "0.12.0" }
ballista-core = { path = "../ballista/core", version = "0.12.0" }
-datafusion = "35.0.0"
-datafusion-proto = "35.0.0"
+datafusion = "41.0.0"
+datafusion-proto = "41.0.0"
+datafusion-python = "41.0.0"
-# we need to use a recent build of ADP that has a public PyDataFrame
-datafusion-python = { git =
"https://github.com/apache/arrow-datafusion-python", rev =
"5296c0cfcf8e6fcb654d5935252469bf04f929e9" }
-
-pyo3 = { version = "0.20", features = ["extension-module", "abi3",
"abi3-py38"] }
+pyo3 = { version = "0.21", features = ["extension-module", "abi3",
"abi3-py38"] }
+pyo3-log = "0.11.0"
tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread",
"sync"] }
[lib]
diff --git a/python/src/context.rs b/python/src/context.rs
index 0d0231c6..be7dd610 100644
--- a/python/src/context.rs
+++ b/python/src/context.rs
@@ -16,6 +16,7 @@
// under the License.
use crate::utils::to_pyerr;
+use datafusion::logical_expr::SortExpr;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use std::path::PathBuf;
@@ -30,7 +31,7 @@ use datafusion_python::context::{
};
use datafusion_python::dataframe::PyDataFrame;
use datafusion_python::errors::DataFusionError;
-use datafusion_python::expr::PyExpr;
+use datafusion_python::expr::sort_expr::PySortExpr;
use datafusion_python::sql::logical::PyLogicalPlan;
use datafusion_python::utils::wait_for_future;
@@ -187,7 +188,7 @@ impl PySessionContext {
file_extension: &str,
skip_metadata: bool,
schema: Option<PyArrowType<Schema>>,
- file_sort_order: Option<Vec<Vec<PyExpr>>>,
+ file_sort_order: Option<Vec<Vec<PySortExpr>>>,
py: Python,
) -> PyResult<PyDataFrame> {
let mut options = ParquetReadOptions::default()
@@ -199,7 +200,14 @@ impl PySessionContext {
options.file_sort_order = file_sort_order
.unwrap_or_default()
.into_iter()
- .map(|e| e.into_iter().map(|f| f.into()).collect())
+ .map(|e| {
+ e.into_iter()
+ .map(|f| {
+ let sort_expr: SortExpr = f.into();
+ *sort_expr.expr
+ })
+ .collect()
+ })
.collect();
let result = self.ctx.read_parquet(path, options);
@@ -299,7 +307,7 @@ impl PySessionContext {
file_extension: &str,
skip_metadata: bool,
schema: Option<PyArrowType<Schema>>,
- file_sort_order: Option<Vec<Vec<PyExpr>>>,
+ file_sort_order: Option<Vec<Vec<PySortExpr>>>,
py: Python,
) -> PyResult<()> {
let mut options = ParquetReadOptions::default()
@@ -311,7 +319,14 @@ impl PySessionContext {
options.file_sort_order = file_sort_order
.unwrap_or_default()
.into_iter()
- .map(|e| e.into_iter().map(|f| f.into()).collect())
+ .map(|e| {
+ e.into_iter()
+ .map(|f| {
+ let sort_expr: SortExpr = f.into();
+ *sort_expr.expr
+ })
+ .collect()
+ })
.collect();
let result = self.ctx.register_parquet(name, path, options);
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 04cf232a..5fbd2491 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -22,7 +22,8 @@ mod utils;
pub use crate::context::PySessionContext;
#[pymodule]
-fn pyballista_internal(_py: Python, m: &PyModule) -> PyResult<()> {
+fn pyballista_internal(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
+ pyo3_log::init();
// Ballista structs
m.add_class::<PySessionContext>()?;
// DataFusion structs
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]