This is an automated email from the ASF dual-hosted git repository.

milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 7d9fcd0f Ensure stage-level sort requirements are enforced in 
distributed planning (#1306)
7d9fcd0f is described below

commit 7d9fcd0fafd914e7a67b4ebeae31abd460803e91
Author: mete <[email protected]>
AuthorDate: Mon Sep 1 19:02:05 2025 +0300

    Ensure stage-level sort requirements are enforced in distributed planning 
(#1306)
    
    * Enforce sort in stages and some tests
    
    * Cargo toml format
---
 ballista/client/Cargo.toml                        |   3 +
 ballista/client/testdata/bug_1296/date_dim.csv    |  11 ++
 ballista/client/testdata/bug_1296/item.csv        |   6 +
 ballista/client/testdata/bug_1296/store.csv       |   4 +
 ballista/client/testdata/bug_1296/store_sales.csv |  16 +++
 ballista/client/tests/bugs.rs                     | 147 ++++++++++++++++++++++
 ballista/scheduler/src/planner.rs                 |  60 ++++++---
 ballista/scheduler/src/state/execution_graph.rs   |   3 +-
 8 files changed, 235 insertions(+), 15 deletions(-)

diff --git a/ballista/client/Cargo.toml b/ballista/client/Cargo.toml
index 38e83b76..a64e248a 100644
--- a/ballista/client/Cargo.toml
+++ b/ballista/client/Cargo.toml
@@ -51,3 +51,6 @@ tonic = { workspace = true }
 [features]
 default = ["standalone"]
 standalone = ["ballista-executor", "ballista-scheduler"]
+# gates tests that need RUST_MIN_STACK to be changed in order
+# for them to run.
+test_extended_stack = []
diff --git a/ballista/client/testdata/bug_1296/date_dim.csv 
b/ballista/client/testdata/bug_1296/date_dim.csv
new file mode 100644
index 00000000..12dd1d19
--- /dev/null
+++ b/ballista/client/testdata/bug_1296/date_dim.csv
@@ -0,0 +1,11 @@
+d_date_sk,d_date_id,d_date,d_year,d_moy,d_month_name
+1,19981201,1998-12-01,1998,12,December
+2,19981215,1998-12-15,1998,12,December
+3,19990105,1999-01-05,1999,1,January
+4,19990210,1999-02-10,1999,2,February
+5,19990315,1999-03-15,1999,3,March
+6,19990420,1999-04-20,1999,4,April
+7,19990525,1999-05-25,1999,5,May
+8,19990630,1999-06-30,1999,6,June
+9,20000105,2000-01-05,2000,1,January
+10,20000120,2000-01-20,2000,1,January
\ No newline at end of file
diff --git a/ballista/client/testdata/bug_1296/item.csv 
b/ballista/client/testdata/bug_1296/item.csv
new file mode 100644
index 00000000..3818fabc
--- /dev/null
+++ b/ballista/client/testdata/bug_1296/item.csv
@@ -0,0 +1,6 @@
+i_item_sk,i_item_id,i_item_name,i_category,i_brand,i_price
+20,ITM001,Laptop Pro,Electronics,TechBrand,999.99
+21,ITM002,Wireless Headphones,Electronics,SoundTech,199.99
+22,ITM003,Cotton T-Shirt,Clothing,FashionNow,19.99
+23,ITM004,Running Shoes,Footwear,SpeedWay,89.99
+24,ITM005,Coffee Maker,Appliances,HomeGoods,79.99
\ No newline at end of file
diff --git a/ballista/client/testdata/bug_1296/store.csv 
b/ballista/client/testdata/bug_1296/store.csv
new file mode 100644
index 00000000..27c7080e
--- /dev/null
+++ b/ballista/client/testdata/bug_1296/store.csv
@@ -0,0 +1,4 @@
+s_store_sk,s_store_id,s_store_name,s_company_name,s_city,s_state
+10,ST001,Downtown Store,Retail Corp,New York,NY
+11,ST002,Mall Store,Retail Corp,Los Angeles,CA
+12,ST003,Uptown Market,Market Group,Chicago,IL
\ No newline at end of file
diff --git a/ballista/client/testdata/bug_1296/store_sales.csv 
b/ballista/client/testdata/bug_1296/store_sales.csv
new file mode 100644
index 00000000..f2e983e5
--- /dev/null
+++ b/ballista/client/testdata/bug_1296/store_sales.csv
@@ -0,0 +1,16 @@
+ss_item_sk,ss_sold_date_sk,ss_store_sk,ss_quantity,ss_sales_price
+20,1,10,2,999.99
+20,2,10,3,999.99
+22,2,11,10,19.99
+20,3,10,4,999.99
+20,3,10,1,999.99
+22,3,11,8,19.99
+20,4,10,9,999.99
+20,4,10,1,999.99
+22,4,11,15,19.99
+20,5,10,5,999.99
+22,5,11,9,19.99
+20,6,10,4,999.99
+22,6,11,8,19.99
+20,9,10,3,999.99
+22,9,11,10,19.99
\ No newline at end of file
diff --git a/ballista/client/tests/bugs.rs b/ballista/client/tests/bugs.rs
new file mode 100644
index 00000000..02b21da2
--- /dev/null
+++ b/ballista/client/tests/bugs.rs
@@ -0,0 +1,147 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod common;
+
+// these tests require RUST_MIN_STACK to be changed before they are run
+//
+// ```
+// export RUST_MIN_STACK=20971520
+// cargo test --features test_extended_stack --test bugs
+// ```
+#[cfg(test)]
+#[cfg(feature = "standalone")]
+#[cfg(feature = "test_extended_stack")]
+mod extended {
+
+    use ballista::prelude::SessionContextExt;
+    use datafusion::{assert_batches_eq, prelude::*};
+    //
+    // tests bug: [Failed to execute TPC-DS Q47 in ballista](https://github.com/apache/datafusion-ballista/issues/1296)
+    // where query execution fails due to missing sort information
+    //
+    #[tokio::test]
+    async fn bug_1296_basic() -> datafusion::error::Result<()> {
+        let test_data = crate::common::example_test_data();
+
+        let ctx = SessionContext::standalone().await?;
+
+        ctx.register_csv(
+            "date_dim",
+            &format!("{}/bug_1296/date_dim.csv", test_data),
+            CsvReadOptions::new().has_header(true),
+        )
+        .await?;
+
+        ctx.register_csv(
+            "store",
+            &format!("{}/bug_1296/store.csv", test_data),
+            CsvReadOptions::new().has_header(true),
+        )
+        .await?;
+
+        ctx.register_csv(
+            "item",
+            &format!("{}/bug_1296/item.csv", test_data),
+            CsvReadOptions::new().has_header(true),
+        )
+        .await?;
+
+        ctx.register_csv(
+            "store_sales",
+            &format!("{}/bug_1296/store_sales.csv", test_data),
+            CsvReadOptions::new().has_header(true),
+        )
+        .await?;
+
+        let query = r#"
+with v1 as(
+    select i_category,
+           i_brand,
+           s_store_name,
+           s_company_name,
+           d_year,
+           d_moy,
+           sum(ss_sales_price) sum_sales,
+           avg(sum(ss_sales_price)) over
+        (partition by i_category, i_brand,
+                     s_store_name, s_company_name, d_year)
+          avg_monthly_sales,
+            rank() over
+        (partition by i_category, i_brand,
+                   s_store_name, s_company_name
+         order by d_year, d_moy) rn
+    from item, store_sales, date_dim, store
+    where ss_item_sk = i_item_sk and
+            ss_sold_date_sk = d_date_sk
+      and ss_store_sk = s_store_sk and
+        (
+                    d_year = 1999 or
+                    ( d_year = 1999-1 and d_moy =12) or
+                    ( d_year = 1999+1 and d_moy =1)
+            )
+    group by i_category, i_brand,
+             s_store_name, s_company_name,
+             d_year, d_moy),
+     v2 as(
+         select v1.i_category,
+                v1.i_brand,
+                v1.s_store_name,
+                v1.s_company_name,
+                v1.d_year,
+                v1.d_moy,
+                v1.avg_monthly_sales,
+                v1.sum_sales,
+                v1_lag.sum_sales psum,
+                v1_lead.sum_sales nsum
+         from v1, v1 v1_lag, v1 v1_lead
+         where v1.i_category = v1_lag.i_category and
+                 v1.i_category = v1_lead.i_category and
+                 v1.i_brand = v1_lag.i_brand and
+                 v1.i_brand = v1_lead.i_brand and
+                 v1.s_store_name = v1_lag.s_store_name and
+                 v1.s_store_name = v1_lead.s_store_name and
+                 v1.s_company_name = v1_lag.s_company_name and
+                 v1.s_company_name = v1_lead.s_company_name and
+                 v1.rn = v1_lag.rn + 1 and
+                 v1.rn = v1_lead.rn - 1)
+select  *
+from v2
+where  d_year = 1999 and
+        avg_monthly_sales > 0 and
+        case when avg_monthly_sales > 0 then abs(sum_sales - 
avg_monthly_sales) / avg_monthly_sales else null end > 0.1
+order by sum_sales - avg_monthly_sales, s_store_name
+    limit 100;
+    "#;
+
+        let result = ctx.sql(query).await?.collect().await?;
+        let expected = [
+    
"+-------------+-----------+----------------+----------------+--------+-------+--------------------+-----------+---------+---------+",
+    "| i_category  | i_brand   | s_store_name   | s_company_name | d_year | 
d_moy | avg_monthly_sales  | sum_sales | psum    | nsum    |",
+    
"+-------------+-----------+----------------+----------------+--------+-------+--------------------+-----------+---------+---------+",
+    "| Electronics | TechBrand | Downtown Store | Retail Corp    | 1999   | 4  
   | 1499.9850000000001 | 999.99    | 999.99  | 999.99  |",
+    "| Electronics | TechBrand | Downtown Store | Retail Corp    | 1999   | 3  
   | 1499.9850000000001 | 999.99    | 1999.98 | 999.99  |",
+    "| Electronics | TechBrand | Downtown Store | Retail Corp    | 1999   | 1  
   | 1499.9850000000001 | 1999.98   | 1999.98 | 1999.98 |",
+    "| Electronics | TechBrand | Downtown Store | Retail Corp    | 1999   | 2  
   | 1499.9850000000001 | 1999.98   | 1999.98 | 999.99  |",
+    
"+-------------+-----------+----------------+----------------+--------+-------+--------------------+-----------+---------+---------+",
+        ];
+
+        assert_batches_eq!(expected, &result);
+
+        Ok(())
+    }
+}
diff --git a/ballista/scheduler/src/planner.rs 
b/ballista/scheduler/src/planner.rs
index e11b5147..daf4d971 100644
--- a/ballista/scheduler/src/planner.rs
+++ b/ballista/scheduler/src/planner.rs
@@ -25,6 +25,9 @@ use ballista_core::{
     execution_plans::{ShuffleReaderExec, ShuffleWriterExec, 
UnresolvedShuffleExec},
     serde::scheduler::PartitionLocation,
 };
+use datafusion::config::ConfigOptions;
+use datafusion::physical_optimizer::enforce_sorting::EnforceSorting;
+use datafusion::physical_optimizer::PhysicalOptimizerRule;
 use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
 use datafusion::physical_plan::repartition::RepartitionExec;
 use 
datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
@@ -45,16 +48,24 @@ pub trait DistributedPlanner {
         &'a mut self,
         job_id: &'a str,
         execution_plan: Arc<dyn ExecutionPlan>,
+        config: &ConfigOptions,
     ) -> Result<Vec<Arc<ShuffleWriterExec>>>;
 }
 /// Default implementation of [DistributedPlanner]
 pub struct DefaultDistributedPlanner {
     next_stage_id: usize,
+    optimizer_enforce_sorting: EnforceSorting,
 }
 
 impl DefaultDistributedPlanner {
     pub fn new() -> Self {
-        Self { next_stage_id: 0 }
+        Self {
+            next_stage_id: 0,
+            // when the plan is broken into stages, some sorting information may get lost in the process,
+            // so stage re-optimisation is needed to adjust the sort information
+            optimizer_enforce_sorting:
+                
datafusion::physical_optimizer::enforce_sorting::EnforceSorting::default(),
+        }
     }
 }
 
@@ -72,10 +83,11 @@ impl DistributedPlanner for DefaultDistributedPlanner {
         &'a mut self,
         job_id: &'a str,
         execution_plan: Arc<dyn ExecutionPlan>,
+        config: &ConfigOptions,
     ) -> Result<Vec<Arc<ShuffleWriterExec>>> {
         info!("planning query stages for job {job_id}");
         let (new_plan, mut stages) =
-            self.plan_query_stages_internal(job_id, execution_plan)?;
+            self.plan_query_stages_internal(job_id, execution_plan, config)?;
         stages.push(create_shuffle_writer(
             job_id,
             self.next_stage_id(),
@@ -94,6 +106,7 @@ impl DefaultDistributedPlanner {
         &'a mut self,
         job_id: &'a str,
         execution_plan: Arc<dyn ExecutionPlan>,
+        config: &ConfigOptions,
     ) -> Result<PartialQueryStageResult> {
         // recurse down and replace children
         if execution_plan.children().is_empty() {
@@ -104,7 +117,7 @@ impl DefaultDistributedPlanner {
         let mut children = vec![];
         for child in execution_plan.children() {
             let (new_child, mut child_stages) =
-                self.plan_query_stages_internal(job_id, child.clone())?;
+                self.plan_query_stages_internal(job_id, child.clone(), 
config)?;
             children.push(new_child);
             stages.append(&mut child_stages);
         }
@@ -113,13 +126,12 @@ impl DefaultDistributedPlanner {
             .as_any()
             .downcast_ref::<CoalescePartitionsExec>()
         {
-            let shuffle_writer = create_shuffle_writer(
-                job_id,
-                self.next_stage_id(),
-                children[0].clone(),
-                None,
-            )?;
+            let input = children[0].clone();
+            let input = self.optimizer_enforce_sorting.optimize(input, 
config)?;
+            let shuffle_writer =
+                create_shuffle_writer(job_id, self.next_stage_id(), input, 
None)?;
             let unresolved_shuffle = 
create_unresolved_shuffle(&shuffle_writer);
+
             stages.push(shuffle_writer);
             Ok((
                 with_new_children_if_necessary(execution_plan, 
vec![unresolved_shuffle])?,
@@ -146,13 +158,17 @@ impl DefaultDistributedPlanner {
         {
             match repart.properties().output_partitioning() {
                 Partitioning::Hash(_, _) => {
+                    let input = children[0].clone();
+                    let input = self.optimizer_enforce_sorting.optimize(input, 
config)?;
+
                     let shuffle_writer = create_shuffle_writer(
                         job_id,
                         self.next_stage_id(),
-                        children[0].clone(),
+                        input,
                         Some(repart.partitioning().to_owned()),
                     )?;
                     let unresolved_shuffle = 
create_unresolved_shuffle(&shuffle_writer);
+
                     stages.push(shuffle_writer);
                     Ok((unresolved_shuffle, stages))
                 }
@@ -362,7 +378,11 @@ mod test {
 
         let mut planner = DefaultDistributedPlanner::new();
         let job_uuid = Uuid::new_v4();
-        let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
+        let stages = planner.plan_query_stages(
+            &job_uuid.to_string(),
+            plan,
+            ctx.state().config().options(),
+        )?;
         for (i, stage) in stages.iter().enumerate() {
             println!("Stage {i}:\n{}", 
displayable(stage.as_ref()).indent(false));
         }
@@ -476,7 +496,11 @@ order by
 
         let mut planner = DefaultDistributedPlanner::new();
         let job_uuid = Uuid::new_v4();
-        let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
+        let stages = planner.plan_query_stages(
+            &job_uuid.to_string(),
+            plan,
+            ctx.state().config().options(),
+        )?;
         for (i, stage) in stages.iter().enumerate() {
             println!("Stage {i}:\n{}", 
displayable(stage.as_ref()).indent(false));
         }
@@ -644,7 +668,11 @@ order by
 
         let mut planner = DefaultDistributedPlanner::new();
         let job_uuid = Uuid::new_v4();
-        let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
+        let stages = planner.plan_query_stages(
+            &job_uuid.to_string(),
+            plan,
+            ctx.state().config().options(),
+        )?;
         for (i, stage) in stages.iter().enumerate() {
             println!("Stage {i}:\n{}", 
displayable(stage.as_ref()).indent(false));
         }
@@ -754,7 +782,11 @@ order by
 
         let mut planner = DefaultDistributedPlanner::new();
         let job_uuid = Uuid::new_v4();
-        let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
+        let stages = planner.plan_query_stages(
+            &job_uuid.to_string(),
+            plan,
+            ctx.state().config().options(),
+        )?;
 
         let partial_hash = stages[0].children()[0].clone();
         let partial_hash_serde = roundtrip_operator(&ctx, 
partial_hash.clone())?;
diff --git a/ballista/scheduler/src/state/execution_graph.rs 
b/ballista/scheduler/src/state/execution_graph.rs
index b42c3806..c5cfbe9d 100644
--- a/ballista/scheduler/src/state/execution_graph.rs
+++ b/ballista/scheduler/src/state/execution_graph.rs
@@ -150,7 +150,8 @@ impl ExecutionGraph {
         planner: &mut dyn DistributedPlanner,
     ) -> Result<Self> {
         let output_partitions = 
plan.properties().output_partitioning().partition_count();
-        let shuffle_stages = planner.plan_query_stages(job_id, plan)?;
+        let shuffle_stages =
+            planner.plan_query_stages(job_id, plan, session_config.options())?;
 
         let builder = ExecutionStageBuilder::new(session_config.clone());
         let stages = builder.build(shuffle_stages)?;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to