This is an automated email from the ASF dual-hosted git repository.
milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git
The following commit(s) were added to refs/heads/main by this push:
new 62932e7c fix: shuffle reader should return statistics (#1302)
62932e7c is described below
commit 62932e7ca21d59487ddb4f0c6fea873ba15741e2
Author: Marko Milenković <[email protected]>
AuthorDate: Mon Sep 1 20:18:30 2025 +0100
fix: shuffle reader should return statistics (#1302)
---
ballista/client/tests/context_checks.rs | 39 ++++
.../core/src/execution_plans/shuffle_reader.rs | 224 ++++++++++++++++++++-
ballista/scheduler/src/state/execution_graph.rs | 8 +-
ballista/scheduler/src/state/execution_stage.rs | 12 +-
4 files changed, 265 insertions(+), 18 deletions(-)
diff --git a/ballista/client/tests/context_checks.rs
b/ballista/client/tests/context_checks.rs
index 125b26b3..1d1c710a 100644
--- a/ballista/client/tests/context_checks.rs
+++ b/ballista/client/tests/context_checks.rs
@@ -743,4 +743,43 @@ mod supported {
Ok(())
}
+
+ #[rstest]
+ #[case::standalone(standalone_context())]
+ #[case::remote(remote_context())]
+ #[case::standalone_state(standalone_context_with_state())]
+ #[case::remote_state(remote_context_with_state())]
+ #[tokio::test]
+ async fn should_execute_group_by(
+ #[future(awt)]
+ #[case]
+ ctx: SessionContext,
+ test_data: String,
+ ) -> datafusion::error::Result<()> {
+ ctx.register_parquet(
+ "test",
+ &format!("{test_data}/alltypes_plain.parquet"),
+ Default::default(),
+ )
+ .await?;
+
+ let expected = [
+ "+------------+----------+",
+ "| string_col | count(*) |",
+ "+------------+----------+",
+ "| 30 | 1 |",
+ "| 31 | 2 |",
+ "+------------+----------+",
+ ];
+
+ let result = ctx
+ .sql("select string_col, count(*) from test where id > 4 group by
string_col order by string_col")
+ .await?
+ .collect()
+ .await?;
+
+ assert_batches_eq!(expected, &result);
+
+ Ok(())
+ }
}
diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs
b/ballista/core/src/execution_plans/shuffle_reader.rs
index caff9adf..b0c0f04f 100644
--- a/ballista/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/core/src/execution_plans/shuffle_reader.rs
@@ -49,7 +49,7 @@ use crate::error::BallistaError;
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use itertools::Itertools;
-use log::{debug, error};
+use log::{debug, error, trace};
use rand::prelude::SliceRandom;
use rand::rng;
use tokio::sync::{mpsc, Semaphore};
@@ -180,7 +180,6 @@ impl ExecutionPlan for ShuffleReaderExec {
.collect();
// Shuffle partitions to spread partition-fetch requests evenly and
avoid hot executors within multiple tasks
partition_locations.shuffle(&mut rng());
-
let response_receiver =
send_fetch_partitions(partition_locations, max_request_num,
max_message_size);
@@ -195,17 +194,75 @@ impl ExecutionPlan for ShuffleReaderExec {
Some(self.metrics.clone_inner())
}
- fn statistics(&self) -> Result<Statistics> {
- Ok(stats_for_partitions(
- self.schema.fields().len(),
- self.partition
- .iter()
- .flatten()
- .map(|loc| loc.partition_stats),
- ))
+ fn partition_statistics(&self, partition: Option<usize>) ->
Result<Statistics> {
+ if let Some(idx) = partition {
+ let partition_count =
self.properties().partitioning.partition_count();
+ if idx >= partition_count {
+ return datafusion::common::internal_err!(
+ "Invalid partition index: {}, the partition count is {}",
+ idx,
+ partition_count
+ );
+ }
+ let stat_for_partition =
+ stats_for_partition(idx, self.schema.fields().len(),
&self.partition);
+
+ trace!(
+ "shuffle reader at stage: {} and partition {} returned
statistics: {:?}",
+ self.stage_id,
+ idx,
+ stat_for_partition
+ );
+ stat_for_partition
+ } else {
+ let stats_for_partitions = stats_for_partitions(
+ self.schema.fields().len(),
+ self.partition
+ .iter()
+ .flatten()
+ .map(|loc| loc.partition_stats),
+ );
+ trace!("shuffle reader at stage: {} returned statistics for all
partitions: {:?}", self.stage_id, stats_for_partitions);
+ Ok(stats_for_partitions)
+ }
}
}
+fn stats_for_partition(
+ partition: usize,
+ num_fields: usize,
+ partition_locations: &[Vec<PartitionLocation>],
+) -> Result<Statistics> {
+ // TODO stats: add column statistics to PartitionStats
+ let (num_rows, total_byte_size) = partition_locations
+ .iter()
+ .map(|location| {
+ // extract requested partitions
+ location
+ .get(partition)
+ .map(|p| p.partition_stats)
+ .map(|p| (p.num_rows, p.num_bytes))
+ .unwrap_or_default()
+ })
+ .fold(
+ (Some(0), Some(0)),
+ |(num_rows, total_byte_size), (rows, bytes)| {
+ (
+ num_rows.zip(rows).map(|(a, b)| a + b as usize),
+ total_byte_size.zip(bytes).map(|(a, b)| a + b as usize),
+ )
+ },
+ );
+
+ Ok(Statistics {
+ num_rows: num_rows.map(Precision::Exact).unwrap_or(Precision::Absent),
+ total_byte_size: total_byte_size
+ .map(Precision::Exact)
+ .unwrap_or(Precision::Absent),
+ column_statistics: vec![ColumnStatistics::new_unknown(); num_fields],
+ })
+}
+
fn stats_for_partitions(
num_fields: usize,
partition_stats: impl Iterator<Item = PartitionStats>,
@@ -220,7 +277,6 @@ fn stats_for_partitions(
.map(|(a, b)| a + b as usize);
(num_rows, total_byte_size)
});
-
Statistics {
num_rows: num_rows.map(Precision::Exact).unwrap_or(Precision::Absent),
total_byte_size: total_byte_size
@@ -543,6 +599,152 @@ mod tests {
assert_eq!(result, exptected);
}
+ #[tokio::test]
+ async fn test_stats_for_partition_statistics_no_specific_partition() ->
Result<()> {
+ let schema = Schema::new(vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Int32, false),
+ Field::new("c", DataType::Int32, false),
+ ]);
+
+ let job_id = "test_job_1";
+ let input_stage_id = 2;
+ let mut partitions: Vec<PartitionLocation> = vec![];
+ for partition_id in 0..4 {
+ partitions.push(PartitionLocation {
+ map_partition_id: 0,
+ partition_id: PartitionId {
+ job_id: job_id.to_string(),
+ stage_id: input_stage_id,
+ partition_id,
+ },
+ executor_meta: ExecutorMetadata {
+ id: "executor_1".to_string(),
+ host: "executor_1".to_string(),
+ port: 7070,
+ grpc_port: 8080,
+ specification: ExecutorSpecification { task_slots: 1 },
+ },
+ partition_stats: PartitionStats {
+ num_rows: Some(1),
+ num_batches: None,
+ num_bytes: Some(10),
+ },
+ path: "test_path".to_string(),
+ })
+ }
+
+ let shuffle_reader_exec = ShuffleReaderExec::try_new(
+ input_stage_id,
+ vec![partitions.clone(), partitions],
+ Arc::new(schema),
+ Partitioning::UnknownPartitioning(4),
+ )?;
+
+ let stats = shuffle_reader_exec.partition_statistics(None)?;
+ assert_eq!(8, *stats.num_rows.get_value().unwrap());
+ assert_eq!(80, *stats.total_byte_size.get_value().unwrap());
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_stats_for_partition_statistics_specific_partition() ->
Result<()> {
+ let schema = Schema::new(vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Int32, false),
+ Field::new("c", DataType::Int32, false),
+ ]);
+
+ let job_id = "test_job_1";
+ let input_stage_id = 2;
+ let mut partitions: Vec<PartitionLocation> = vec![];
+ for partition_id in 0..4 {
+ partitions.push(PartitionLocation {
+ map_partition_id: 0,
+ partition_id: PartitionId {
+ job_id: job_id.to_string(),
+ stage_id: input_stage_id,
+ partition_id,
+ },
+ executor_meta: ExecutorMetadata {
+ id: "executor_1".to_string(),
+ host: "executor_1".to_string(),
+ port: 7070,
+ grpc_port: 8080,
+ specification: ExecutorSpecification { task_slots: 1 },
+ },
+ partition_stats: PartitionStats {
+ num_rows: Some(1),
+ num_batches: None,
+ num_bytes: Some(10),
+ },
+ path: "test_path".to_string(),
+ })
+ }
+
+ let shuffle_reader_exec = ShuffleReaderExec::try_new(
+ input_stage_id,
+ vec![partitions.clone(), partitions],
+ Arc::new(schema),
+ Partitioning::UnknownPartitioning(4),
+ )?;
+
+ let stats = shuffle_reader_exec.partition_statistics(Some(3))?;
+ assert_eq!(2, *stats.num_rows.get_value().unwrap());
+ assert_eq!(20, *stats.total_byte_size.get_value().unwrap());
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn
test_stats_for_partition_statistics_specific_partition_out_of_range(
+ ) -> Result<()> {
+ let schema = Schema::new(vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Int32, false),
+ Field::new("c", DataType::Int32, false),
+ ]);
+
+ let job_id = "test_job_1";
+ let input_stage_id = 2;
+ let mut partitions: Vec<PartitionLocation> = vec![];
+ for partition_id in 0..4 {
+ partitions.push(PartitionLocation {
+ map_partition_id: 0,
+ partition_id: PartitionId {
+ job_id: job_id.to_string(),
+ stage_id: input_stage_id,
+ partition_id,
+ },
+ executor_meta: ExecutorMetadata {
+ id: "executor_1".to_string(),
+ host: "executor_1".to_string(),
+ port: 7070,
+ grpc_port: 8080,
+ specification: ExecutorSpecification { task_slots: 1 },
+ },
+ partition_stats: PartitionStats {
+ num_rows: Some(1),
+ num_batches: None,
+ num_bytes: Some(10),
+ },
+ path: "test_path".to_string(),
+ })
+ }
+
+ let shuffle_reader_exec = ShuffleReaderExec::try_new(
+ input_stage_id,
+ vec![partitions.clone(), partitions],
+ Arc::new(schema),
+ Partitioning::UnknownPartitioning(4),
+ )?;
+
+ let stats = shuffle_reader_exec.partition_statistics(Some(4));
+ assert!(stats.is_err());
+
+ Ok(())
+ }
#[tokio::test]
async fn test_fetch_partitions_error_mapping() -> Result<()> {
diff --git a/ballista/scheduler/src/state/execution_graph.rs
b/ballista/scheduler/src/state/execution_graph.rs
index c5cfbe9d..fb5d470d 100644
--- a/ballista/scheduler/src/state/execution_graph.rs
+++ b/ballista/scheduler/src/state/execution_graph.rs
@@ -1142,8 +1142,12 @@ impl ExecutionGraph {
/// Convert unresolved stage to be resolved
pub fn resolve_stage(&mut self, stage_id: usize) -> Result<bool> {
if let Some(ExecutionStage::UnResolved(stage)) =
self.stages.remove(&stage_id) {
- self.stages
- .insert(stage_id,
ExecutionStage::Resolved(stage.to_resolved()?));
+ self.stages.insert(
+ stage_id,
+ ExecutionStage::Resolved(
+ stage.to_resolved(self.session_config.options())?,
+ ),
+ );
Ok(true)
} else {
warn!(
diff --git a/ballista/scheduler/src/state/execution_stage.rs
b/ballista/scheduler/src/state/execution_stage.rs
index 2f2ef0e8..59e068c6 100644
--- a/ballista/scheduler/src/state/execution_stage.rs
+++ b/ballista/scheduler/src/state/execution_stage.rs
@@ -21,6 +21,7 @@ use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
+use datafusion::config::ConfigOptions;
use datafusion::physical_optimizer::aggregate_statistics::AggregateStatistics;
use datafusion::physical_optimizer::join_selection::JoinSelection;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
@@ -350,7 +351,7 @@ impl UnresolvedStage {
}
/// Change to the resolved state
- pub fn to_resolved(&self) -> Result<ResolvedStage> {
+ pub fn to_resolved(&self, options: &ConfigOptions) ->
Result<ResolvedStage> {
let input_locations = self
.inputs
.iter()
@@ -361,13 +362,14 @@ impl UnresolvedStage {
&input_locations,
)?;
+ //
+ // TODO: with datafusion 50 we can add a rule to switch between HashJoin
and SortMergeJoin
+ //
let optimize_join = JoinSelection::new();
- let config = SessionConfig::default();
- let plan = optimize_join.optimize(plan, config.options())?;
+ let plan = optimize_join.optimize(plan, options)?;
let optimize_aggregate = AggregateStatistics::new();
- let plan =
- optimize_aggregate.optimize(plan,
SessionConfig::default().options())?;
+ let plan = optimize_aggregate.optimize(plan, options)?;
Ok(ResolvedStage::new(
self.stage_id,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]