adriangb commented on code in PR #17337: URL: https://github.com/apache/datafusion/pull/17337#discussion_r2429327177
########## datafusion/optimizer/src/push_down_sort.rs: ########## @@ -0,0 +1,580 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`PushDownSort`] pushes sort expressions into table scans to enable +//! sort pushdown optimizations by table providers + +use std::sync::Arc; + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use datafusion_common::tree_node::Transformed; +use datafusion_common::Result; +use datafusion_expr::logical_plan::{LogicalPlan, TableScan}; +use datafusion_expr::{Expr, SortExpr}; + +/// Optimization rule that pushes sort expressions down to table scans +/// when the sort can potentially be optimized by the table provider. +/// +/// This rule looks for `Sort -> TableScan` patterns and moves the sort +/// expressions into the `TableScan.preferred_ordering` field, allowing +/// table providers to potentially optimize the scan based on sort requirements. +/// +/// # Behavior +/// +/// The optimizer preserves the original `Sort` node as a fallback while passing +/// the ordering preference to the `TableScan` as an optimization hint. This ensures +/// correctness even if the table provider cannot satisfy the requested ordering. +/// +/// # Supported Sort Expressions +/// +/// Currently, only simple column references are supported for pushdown because +/// table providers typically cannot optimize complex expressions in sort operations. +/// Complex expressions like `col("a") + col("b")` or function calls are not pushed down. +/// +/// # Examples +/// +/// ```text +/// Before optimization: +/// Sort: test.a ASC NULLS LAST +/// TableScan: test +/// +/// After optimization: +/// Sort: test.a ASC NULLS LAST -- Preserved as fallback +/// TableScan: test -- Now includes preferred_ordering hint +/// ``` +#[derive(Default, Debug)] +pub struct PushDownSort {} + +impl PushDownSort { + /// Creates a new instance of the `PushDownSort` optimizer rule. + /// + /// # Returns + /// + /// A new `PushDownSort` optimizer rule that can be added to the optimization pipeline. + /// + /// # Examples + /// + /// ```rust + /// use datafusion_optimizer::push_down_sort::PushDownSort; + /// + /// let rule = PushDownSort::new(); + /// ``` + pub fn new() -> Self { + Self {} + } + + /// Checks if a sort expression can be pushed down to a table scan. + /// + /// Currently, we only support pushing down simple column references + /// because table providers typically can't optimize complex expressions + /// in sort pushdown. + fn can_pushdown_sort_expr(expr: &SortExpr) -> bool { + // Only push down simple column references + matches!(expr.expr, Expr::Column(_)) + } + + /// Checks if all sort expressions in a list can be pushed down. + fn can_pushdown_sort_exprs(sort_exprs: &[SortExpr]) -> bool { + sort_exprs.iter().all(Self::can_pushdown_sort_expr) + } +} + +impl OptimizerRule for PushDownSort { + fn supports_rewrite(&self) -> bool { + true + } + + fn apply_order(&self) -> Option<ApplyOrder> { + Some(ApplyOrder::TopDown) + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result<Transformed<LogicalPlan>> { + // Look for Sort -> TableScan pattern + let LogicalPlan::Sort(sort) = &plan else { + return Ok(Transformed::no(plan)); + }; + + let LogicalPlan::TableScan(table_scan) = sort.input.as_ref() else { + return Ok(Transformed::no(plan)); + }; + + // Check if we can push down the sort expressions + if !Self::can_pushdown_sort_exprs(&sort.expr) { + return Ok(Transformed::no(plan)); + } + + // If the table scan already has preferred ordering, don't overwrite it + // This preserves any existing sort preferences from other optimizations + if table_scan.preferred_ordering.is_some() { + return Ok(Transformed::no(plan)); + } + + // Create new TableScan with preferred ordering + let new_table_scan = TableScan { + table_name: table_scan.table_name.clone(), + source: Arc::clone(&table_scan.source), + projection: table_scan.projection.clone(), + projected_schema: Arc::clone(&table_scan.projected_schema), + filters: table_scan.filters.clone(), + fetch: table_scan.fetch, + preferred_ordering: Some(sort.expr.clone()), + }; + + // Preserve the Sort node as a fallback while passing the ordering + // preference to the TableScan as an optimization hint + let new_sort = datafusion_expr::logical_plan::Sort { + expr: sort.expr.clone(), + input: Arc::new(LogicalPlan::TableScan(new_table_scan)), + fetch: sort.fetch, + }; + let new_plan = LogicalPlan::Sort(new_sort); + + Ok(Transformed::yes(new_plan)) + } + + fn name(&self) -> &str { + "push_down_sort" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::test_table_scan; + use crate::{assert_optimized_plan_eq_snapshot, OptimizerContext}; + use datafusion_common::{Column, Result}; + use datafusion_expr::{col, lit, Expr, JoinType, LogicalPlanBuilder, SortExpr}; + use std::sync::Arc; + + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownSort::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; + } + + #[test] + fn test_can_pushdown_sort_expr() { + // Simple column reference should be pushable + let sort_expr = SortExpr::new(col("a"), true, false); + assert!(PushDownSort::can_pushdown_sort_expr(&sort_expr)); + + // Complex expression should not be pushable + let sort_expr = SortExpr::new(col("a") + col("b"), true, false); + assert!(!PushDownSort::can_pushdown_sort_expr(&sort_expr)); + + // Function call should not be pushable + let sort_expr = SortExpr::new(col("c").like(lit("test%")), true, false); + assert!(!PushDownSort::can_pushdown_sort_expr(&sort_expr)); + + // Literal should not be pushable + let sort_expr = SortExpr::new(lit(42), true, false); + assert!(!PushDownSort::can_pushdown_sort_expr(&sort_expr)); + } + + #[test] + fn test_can_pushdown_sort_exprs() { + // All simple columns should be pushable + let sort_exprs = vec![ + SortExpr::new(col("a"), true, false), + SortExpr::new(col("b"), false, true), + ]; + assert!(PushDownSort::can_pushdown_sort_exprs(&sort_exprs)); + + // Mix of simple and complex should not be pushable + let sort_exprs = vec![ + SortExpr::new(col("a"), true, false), + SortExpr::new(col("a") + col("b"), false, true), + ]; + assert!(!PushDownSort::can_pushdown_sort_exprs(&sort_exprs)); + + // Empty list should be pushable + let sort_exprs = vec![]; + assert!(PushDownSort::can_pushdown_sort_exprs(&sort_exprs)); + } + + #[test] + fn test_basic_sort_pushdown_to_table_scan() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .sort(vec![SortExpr::new(col("a"), true, false)])? + .build()?; + + // Sort node is preserved with preferred_ordering passed to TableScan + assert_optimized_plan_equal!( + plan, + @ r" + Sort: test.a ASC NULLS LAST + TableScan: test + " + ) + } + + #[test] + fn test_multiple_column_sort_pushdown() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .sort(vec![ + SortExpr::new(col("a"), true, false), + SortExpr::new(col("b"), false, true), + ])? + .build()?; + + // Multi-column sort is preserved with preferred_ordering passed to TableScan + assert_optimized_plan_equal!( + plan, + @ r" + Sort: test.a ASC NULLS LAST, test.b DESC NULLS FIRST + TableScan: test + " + ) + } + + #[test] + fn test_sort_node_preserved_with_preferred_ordering() -> Result<()> { + let rule = PushDownSort::new(); + let table_scan = test_table_scan()?; + let sort_plan = LogicalPlanBuilder::from(table_scan) + .sort(vec![SortExpr::new(col("a"), true, false)])? + .build()?; + + let config = &OptimizerContext::new(); + let result = rule.rewrite(sort_plan, config)?; + + // Verify Sort node is preserved + match &result.data { + LogicalPlan::Sort(sort) => { + // Check that TableScan has preferred_ordering + if let LogicalPlan::TableScan(ts) = sort.input.as_ref() { + assert!(ts.preferred_ordering.is_some()); + } else { + panic!("Expected TableScan input"); + } + } + _ => panic!("Expected Sort node to be preserved"), + } + + Ok(()) + } + + #[test] + fn test_no_pushdown_with_complex_expressions() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .sort(vec![ + SortExpr::new(col("a"), true, false), + SortExpr::new(col("a") + col("b"), false, true), // Complex expression + ])? + .build()?; + + // Sort should remain unchanged + assert_optimized_plan_equal!( + plan, + @ r" + Sort: test.a ASC NULLS LAST, test.a + test.b DESC NULLS FIRST + TableScan: test + " + ) + } + + #[test] + fn test_no_pushdown_through_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .sort(vec![SortExpr::new(col("a"), true, false)])? + .build()?; + + // Sort should remain above projection + assert_optimized_plan_equal!( + plan, + @ r" + Sort: test.a ASC NULLS LAST + Projection: test.a, test.b + TableScan: test + " + ) + } + + #[test] + fn test_no_pushdown_through_filter() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("a").gt(lit(10)))? + .sort(vec![SortExpr::new(col("a"), true, false)])? + .build()?; + + // Sort should remain above filter + assert_optimized_plan_equal!( + plan, + @ r" + Sort: test.a ASC NULLS LAST + Filter: test.a > Int32(10) + TableScan: test Review Comment: resolving since I reworked most of it and added new tests for sort -> filter -> scan -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
