This is an automated email from the ASF dual-hosted git repository.

milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 62932e7c fix: shuffle reader should return statistics (#1302)
62932e7c is described below

commit 62932e7ca21d59487ddb4f0c6fea873ba15741e2
Author: Marko Milenković <[email protected]>
AuthorDate: Mon Sep 1 20:18:30 2025 +0100

    fix: shuffle reader should return statistics (#1302)
---
 ballista/client/tests/context_checks.rs            |  39 ++++
 .../core/src/execution_plans/shuffle_reader.rs     | 224 ++++++++++++++++++++-
 ballista/scheduler/src/state/execution_graph.rs    |   8 +-
 ballista/scheduler/src/state/execution_stage.rs    |  12 +-
 4 files changed, 265 insertions(+), 18 deletions(-)

diff --git a/ballista/client/tests/context_checks.rs b/ballista/client/tests/context_checks.rs
index 125b26b3..1d1c710a 100644
--- a/ballista/client/tests/context_checks.rs
+++ b/ballista/client/tests/context_checks.rs
@@ -743,4 +743,43 @@ mod supported {
 
         Ok(())
     }
+
+    /// Runs a filtered `GROUP BY ... ORDER BY` query against
+    /// `alltypes_plain.parquet` in every context variant (standalone/remote,
+    /// with and without an explicit session state) and asserts the exact
+    /// grouped counts. Added together with the shuffle reader statistics
+    /// fix (#1302).
+    #[rstest]
+    #[case::standalone(standalone_context())]
+    #[case::remote(remote_context())]
+    #[case::standalone_state(standalone_context_with_state())]
+    #[case::remote_state(remote_context_with_state())]
+    #[tokio::test]
+    async fn should_execute_group_by(
+        #[future(awt)]
+        #[case]
+        ctx: SessionContext,
+        test_data: String,
+    ) -> datafusion::error::Result<()> {
+        ctx.register_parquet(
+            "test",
+            &format!("{test_data}/alltypes_plain.parquet"),
+            Default::default(),
+        )
+        .await?;
+
+        let expected = [
+            "+------------+----------+",
+            "| string_col | count(*) |",
+            "+------------+----------+",
+            "| 30         | 1        |",
+            "| 31         | 2        |",
+            "+------------+----------+",
+        ];
+
+        // The trailing `\` continues the literal without inserting a newline,
+        // so the SQL text is the same single-line query.
+        let result = ctx
+            .sql(
+                "select string_col, count(*) from test where id > 4 \
+                 group by string_col order by string_col",
+            )
+            .await?
+            .collect()
+            .await?;
+
+        assert_batches_eq!(expected, &result);
+
+        Ok(())
+    }
 }
diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs
index caff9adf..b0c0f04f 100644
--- a/ballista/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/core/src/execution_plans/shuffle_reader.rs
@@ -49,7 +49,7 @@ use crate::error::BallistaError;
 use datafusion::execution::context::TaskContext;
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use itertools::Itertools;
-use log::{debug, error};
+use log::{debug, error, trace};
 use rand::prelude::SliceRandom;
 use rand::rng;
 use tokio::sync::{mpsc, Semaphore};
@@ -180,7 +180,6 @@ impl ExecutionPlan for ShuffleReaderExec {
             .collect();
         // Shuffle partitions for evenly send fetching partition requests to 
avoid hot executors within multiple tasks
         partition_locations.shuffle(&mut rng());
-
         let response_receiver =
             send_fetch_partitions(partition_locations, max_request_num, 
max_message_size);
 
@@ -195,17 +194,75 @@ impl ExecutionPlan for ShuffleReaderExec {
         Some(self.metrics.clone_inner())
     }
 
-    fn statistics(&self) -> Result<Statistics> {
-        Ok(stats_for_partitions(
-            self.schema.fields().len(),
-            self.partition
-                .iter()
-                .flatten()
-                .map(|loc| loc.partition_stats),
-        ))
+    /// Returns statistics for a single output partition (`Some(idx)`) or for
+    /// the whole shuffle read (`None`).
+    ///
+    /// The per-partition case delegates to `stats_for_partition`; the `None`
+    /// case combines the `partition_stats` of every location via
+    /// `stats_for_partitions`. Both results are logged at `trace` level.
+    ///
+    /// # Errors
+    ///
+    /// Returns an internal error when `idx` is not smaller than the partition
+    /// count declared by this plan's output partitioning.
+    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
+        if let Some(idx) = partition {
+            let partition_count = self.properties().partitioning.partition_count();
+            if idx >= partition_count {
+                return datafusion::common::internal_err!(
+                    "Invalid partition index: {}, the partition count is {}",
+                    idx,
+                    partition_count
+                );
+            }
+            let stat_for_partition =
+                stats_for_partition(idx, self.schema.fields().len(), &self.partition);
+
+            trace!(
+                "shuffle reader at stage: {} and partition {} returned statistics: {:?}",
+                self.stage_id,
+                idx,
+                stat_for_partition
+            );
+            stat_for_partition
+        } else {
+            // Named differently from the helper fn to avoid shadowing it.
+            let all_partition_stats = stats_for_partitions(
+                self.schema.fields().len(),
+                self.partition
+                    .iter()
+                    .flatten()
+                    .map(|loc| loc.partition_stats),
+            );
+            trace!(
+                "shuffle reader at stage: {} returned statistics for all partitions: {:?}",
+                self.stage_id,
+                all_partition_stats
+            );
+            Ok(all_partition_stats)
+        }
     }
 }
 
+/// Computes row and byte statistics for output partition `partition`.
+///
+/// For every entry of `partition_locations`, the location at inner index
+/// `partition` is looked up and its `num_rows` / `num_bytes` are summed. If
+/// an entry has no location at that index, or any summed value is `None`,
+/// the corresponding total becomes `Precision::Absent`; otherwise it is
+/// `Precision::Exact`. Column statistics are always reported as unknown.
+///
+/// NOTE(review): this reads `partition_locations[i][partition]` for every
+/// `i`, i.e. it treats the *inner* index as the output partition. If
+/// `ShuffleReaderExec::execute` reads `self.partition[partition]` (outer
+/// index = output partition), the intended totals are the sum over
+/// `partition_locations[partition]` instead. TODO confirm before relying on
+/// per-partition statistics.
+fn stats_for_partition(
+    partition: usize,
+    num_fields: usize,
+    partition_locations: &[Vec<PartitionLocation>],
+) -> Result<Statistics> {
+    // TODO stats: add column statistics to PartitionStats
+    let (num_rows, total_byte_size) = partition_locations
+        .iter()
+        .map(|location| {
+            // extract requested partitions
+            location
+                .get(partition)
+                .map(|p| p.partition_stats)
+                .map(|p| (p.num_rows, p.num_bytes))
+                // a missing location yields (None, None), which turns both
+                // totals into None in the fold below
+                .unwrap_or_default()
+        })
+        // start from exact zeros; Option::zip keeps a running total only
+        // while every input value is Some
+        .fold(
+            (Some(0), Some(0)),
+            |(num_rows, total_byte_size), (rows, bytes)| {
+                (
+                    num_rows.zip(rows).map(|(a, b)| a + b as usize),
+                    total_byte_size.zip(bytes).map(|(a, b)| a + b as usize),
+                )
+            },
+        );
+
+    Ok(Statistics {
+        num_rows: num_rows.map(Precision::Exact).unwrap_or(Precision::Absent),
+        total_byte_size: total_byte_size
+            .map(Precision::Exact)
+            .unwrap_or(Precision::Absent),
+        column_statistics: vec![ColumnStatistics::new_unknown(); num_fields],
+    })
+}
+
 fn stats_for_partitions(
     num_fields: usize,
     partition_stats: impl Iterator<Item = PartitionStats>,
@@ -220,7 +277,6 @@ fn stats_for_partitions(
                 .map(|(a, b)| a + b as usize);
             (num_rows, total_byte_size)
         });
-
     Statistics {
         num_rows: num_rows.map(Precision::Exact).unwrap_or(Precision::Absent),
         total_byte_size: total_byte_size
@@ -543,6 +599,152 @@ mod tests {
 
         assert_eq!(result, exptected);
     }
+
+    /// `partition_statistics(None)` combines the stats of every location: two
+    /// copies of four locations (1 row / 10 bytes each) give 8 rows / 80 bytes.
+    #[tokio::test]
+    async fn test_stats_for_partition_statistics_no_specific_partition() -> Result<()> {
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+            Field::new("c", DataType::Int32, false),
+        ]);
+
+        let job_id = "test_job_1";
+        let input_stage_id = 2;
+        let mut partitions: Vec<PartitionLocation> = vec![];
+        for partition_id in 0..4 {
+            partitions.push(PartitionLocation {
+                map_partition_id: 0,
+                partition_id: PartitionId {
+                    job_id: job_id.to_string(),
+                    stage_id: input_stage_id,
+                    partition_id,
+                },
+                executor_meta: ExecutorMetadata {
+                    id: "executor_1".to_string(),
+                    host: "executor_1".to_string(),
+                    port: 7070,
+                    grpc_port: 8080,
+                    specification: ExecutorSpecification { task_slots: 1 },
+                },
+                partition_stats: PartitionStats {
+                    num_rows: Some(1),
+                    num_batches: None,
+                    num_bytes: Some(10),
+                },
+                path: "test_path".to_string(),
+            })
+        }
+
+        let shuffle_reader_exec = ShuffleReaderExec::try_new(
+            input_stage_id,
+            vec![partitions.clone(), partitions],
+            Arc::new(schema),
+            Partitioning::UnknownPartitioning(4),
+        )?;
+
+        let stats = shuffle_reader_exec.partition_statistics(None)?;
+        assert_eq!(8, *stats.num_rows.get_value().unwrap());
+        assert_eq!(80, *stats.total_byte_size.get_value().unwrap());
+
+        Ok(())
+    }
+
+    /// `partition_statistics(Some(3))` over two copies of the same four
+    /// locations (1 row / 10 bytes each) is expected to report 2 rows / 20 bytes.
+    ///
+    /// NOTE(review): the fixture passes 2 outer vectors of 4 locations while
+    /// declaring `UnknownPartitioning(4)`, and the expectation holds only when
+    /// the *inner* index selects the output partition. If the outer index of
+    /// `ShuffleReaderExec::partition` is the output partition, this fixture
+    /// should contain 4 outer vectors instead. TODO confirm against `execute`.
+    #[tokio::test]
+    async fn test_stats_for_partition_statistics_specific_partition() -> Result<()> {
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+            Field::new("c", DataType::Int32, false),
+        ]);
+
+        let job_id = "test_job_1";
+        let input_stage_id = 2;
+        let mut partitions: Vec<PartitionLocation> = vec![];
+        for partition_id in 0..4 {
+            partitions.push(PartitionLocation {
+                map_partition_id: 0,
+                partition_id: PartitionId {
+                    job_id: job_id.to_string(),
+                    stage_id: input_stage_id,
+                    partition_id,
+                },
+                executor_meta: ExecutorMetadata {
+                    id: "executor_1".to_string(),
+                    host: "executor_1".to_string(),
+                    port: 7070,
+                    grpc_port: 8080,
+                    specification: ExecutorSpecification { task_slots: 1 },
+                },
+                partition_stats: PartitionStats {
+                    num_rows: Some(1),
+                    num_batches: None,
+                    num_bytes: Some(10),
+                },
+                path: "test_path".to_string(),
+            })
+        }
+
+        let shuffle_reader_exec = ShuffleReaderExec::try_new(
+            input_stage_id,
+            vec![partitions.clone(), partitions],
+            Arc::new(schema),
+            Partitioning::UnknownPartitioning(4),
+        )?;
+
+        let stats = shuffle_reader_exec.partition_statistics(Some(3))?;
+        assert_eq!(2, *stats.num_rows.get_value().unwrap());
+        assert_eq!(20, *stats.total_byte_size.get_value().unwrap());
+
+        Ok(())
+    }
+
+    /// Requesting statistics for partition 4 when the plan declares
+    /// `UnknownPartitioning(4)` must return an error instead of panicking.
+    #[tokio::test]
+    async fn test_stats_for_partition_statistics_specific_partition_out_of_range(
+    ) -> Result<()> {
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+            Field::new("c", DataType::Int32, false),
+        ]);
+
+        let job_id = "test_job_1";
+        let input_stage_id = 2;
+        let mut partitions: Vec<PartitionLocation> = vec![];
+        for partition_id in 0..4 {
+            partitions.push(PartitionLocation {
+                map_partition_id: 0,
+                partition_id: PartitionId {
+                    job_id: job_id.to_string(),
+                    stage_id: input_stage_id,
+                    partition_id,
+                },
+                executor_meta: ExecutorMetadata {
+                    id: "executor_1".to_string(),
+                    host: "executor_1".to_string(),
+                    port: 7070,
+                    grpc_port: 8080,
+                    specification: ExecutorSpecification { task_slots: 1 },
+                },
+                partition_stats: PartitionStats {
+                    num_rows: Some(1),
+                    num_batches: None,
+                    num_bytes: Some(10),
+                },
+                path: "test_path".to_string(),
+            })
+        }
+
+        let shuffle_reader_exec = ShuffleReaderExec::try_new(
+            input_stage_id,
+            vec![partitions.clone(), partitions],
+            Arc::new(schema),
+            Partitioning::UnknownPartitioning(4),
+        )?;
+
+        let stats = shuffle_reader_exec.partition_statistics(Some(4));
+        assert!(stats.is_err());
+
+        Ok(())
+    }
 
     #[tokio::test]
     async fn test_fetch_partitions_error_mapping() -> Result<()> {
diff --git a/ballista/scheduler/src/state/execution_graph.rs b/ballista/scheduler/src/state/execution_graph.rs
index c5cfbe9d..fb5d470d 100644
--- a/ballista/scheduler/src/state/execution_graph.rs
+++ b/ballista/scheduler/src/state/execution_graph.rs
@@ -1142,8 +1142,12 @@ impl ExecutionGraph {
     /// Convert unresolved stage to be resolved
     pub fn resolve_stage(&mut self, stage_id: usize) -> Result<bool> {
         if let Some(ExecutionStage::UnResolved(stage)) = 
self.stages.remove(&stage_id) {
-            self.stages
-                .insert(stage_id, 
ExecutionStage::Resolved(stage.to_resolved()?));
+            self.stages.insert(
+                stage_id,
+                ExecutionStage::Resolved(
+                    stage.to_resolved(self.session_config.options())?,
+                ),
+            );
             Ok(true)
         } else {
             warn!(
diff --git a/ballista/scheduler/src/state/execution_stage.rs b/ballista/scheduler/src/state/execution_stage.rs
index 2f2ef0e8..59e068c6 100644
--- a/ballista/scheduler/src/state/execution_stage.rs
+++ b/ballista/scheduler/src/state/execution_stage.rs
@@ -21,6 +21,7 @@ use std::fmt::{Debug, Formatter};
 use std::sync::Arc;
 use std::time::{SystemTime, UNIX_EPOCH};
 
+use datafusion::config::ConfigOptions;
 use datafusion::physical_optimizer::aggregate_statistics::AggregateStatistics;
 use datafusion::physical_optimizer::join_selection::JoinSelection;
 use datafusion::physical_optimizer::PhysicalOptimizerRule;
@@ -350,7 +351,7 @@ impl UnresolvedStage {
     }
 
     /// Change to the resolved state
-    pub fn to_resolved(&self) -> Result<ResolvedStage> {
+    pub fn to_resolved(&self, options: &ConfigOptions) -> 
Result<ResolvedStage> {
         let input_locations = self
             .inputs
             .iter()
@@ -361,13 +362,14 @@ impl UnresolvedStage {
             &input_locations,
         )?;
 
+        //
+        // TODO: with datafusion 50 we can add rule to switch between HashJoin 
and SortMergeJoin
+        //
         let optimize_join = JoinSelection::new();
-        let config = SessionConfig::default();
-        let plan = optimize_join.optimize(plan, config.options())?;
+        let plan = optimize_join.optimize(plan, options)?;
 
         let optimize_aggregate = AggregateStatistics::new();
-        let plan =
-            optimize_aggregate.optimize(plan, 
SessionConfig::default().options())?;
+        let plan = optimize_aggregate.optimize(plan, options)?;
 
         Ok(ResolvedStage::new(
             self.stage_id,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to