This is an automated email from the ASF dual-hosted git repository.

berkay pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 7a4577e963 perf: Introduce sort prefix computation for early TopK exit optimization on partially sorted input (10x speedup on top10 bench) (#15563)
7a4577e963 is described below

commit 7a4577e963fec6a9af6028fd932b003352141392
Author: Geoffrey Claude <geoffrey.cla...@datadoghq.com>
AuthorDate: Wed Apr 9 12:35:40 2025 +0200

    perf: Introduce sort prefix computation for early TopK exit optimization on partially sorted input (10x speedup on top10 bench) (#15563)
    
    * perf: Introduce sort prefix computation for early TopK exit optimization on partially sorted input
    
    * perf: Use same `common_sort_prefix` nomenclature everywhere
    
    * perf: Remove redundant argument to `sort_partially_satisfied`
    
    * perf: Clarify that the `common_sort_prefix` is normalized
    
    * perf: Update the topk tests for normalized projections
    
    * perf: Rename `worst` to `max` to keep naming consistent with heap nomenclature
    
    * perf: Add `NULLS FIRST` and `NULLS LAST` TopK sql logic tests
    
    * perf: Rename sqllogic topk test columns and reduce batch size
    
    * perf: Update TopK header doc with "Partial Sort Optimization" section
    
    * fix: Reset `SortExec`'s `EmissionType` to `Final` on partially sorted input
    
    - Without a fetch, the entire input data must be sorted before emitting results
    - With a fetch, we can optimize for an early exit, but the results will still be emitted once all the necessary input data has been processed
---
 .../tests/physical_optimizer/enforce_sorting.rs    |   2 +-
 .../src/equivalence/properties/mod.rs              |  59 +++--
 datafusion/physical-plan/src/sorts/sort.rs         |  45 +++-
 datafusion/physical-plan/src/topk/mod.rs           | 267 +++++++++++++++++++--
 datafusion/sqllogictest/test_files/topk.slt        | 162 +++++++++++++
 datafusion/sqllogictest/test_files/window.slt      |   2 +-
 6 files changed, 489 insertions(+), 48 deletions(-)
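
To make the fetch/no-fetch emission distinction in the commit message concrete, here is an illustrative pair of queries (not part of the patch; hypothetical table `t`, assumed pre-sorted on a prefix of the ORDER BY columns):

```sql
-- Hypothetical table `t`, pre-sorted by (a) but not by (a, b):
SELECT a, b FROM t ORDER BY a, b;          -- no fetch: the whole input must be sorted; emission stays Final
SELECT a, b FROM t ORDER BY a, b LIMIT 10; -- fetch: TopK may stop reading input early,
                                           -- but results are still emitted only at the end
```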

diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs
index 4d2c875d3f..d4b84a52f4 100644
--- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs
+++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs
@@ -1652,7 +1652,7 @@ async fn test_remove_unnecessary_sort7() -> Result<()> {
     ) as Arc<dyn ExecutionPlan>;
 
     let expected_input = [
-        "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], 
preserve_partitioning=[false]",
+        "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], 
preserve_partitioning=[false], sort_prefix=[non_nullable_col@1 ASC]",
         "  SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], 
preserve_partitioning=[false]",
         "    DataSourceExec: partitions=1, partition_sizes=[0]",
     ];
diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs
index 9cf9897001..5b34a02a91 100644
--- a/datafusion/physical-expr/src/equivalence/properties/mod.rs
+++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs
@@ -546,22 +546,26 @@ impl EquivalenceProperties {
         self.ordering_satisfy_requirement(&sort_requirements)
     }
 
-    /// Checks whether the given sort requirements are satisfied by any of the
-    /// existing orderings.
-    pub fn ordering_satisfy_requirement(&self, reqs: &LexRequirement) -> bool {
-        let mut eq_properties = self.clone();
-        // First, standardize the given requirement:
-        let normalized_reqs = eq_properties.normalize_sort_requirements(reqs);
-
+    /// Returns the number of consecutive requirements (starting from the left)
+    /// that are satisfied by the plan ordering.
+    fn compute_common_sort_prefix_length(
+        &self,
+        normalized_reqs: &LexRequirement,
+    ) -> usize {
         // Check whether given ordering is satisfied by constraints first
-        if self.satisfied_by_constraints(&normalized_reqs) {
-            return true;
+        if self.satisfied_by_constraints(normalized_reqs) {
+            // If the constraints satisfy all requirements, return the full normalized requirements length
+            return normalized_reqs.len();
         }
 
-        for normalized_req in normalized_reqs {
+        let mut eq_properties = self.clone();
+
+        for (i, normalized_req) in normalized_reqs.iter().enumerate() {
             // Check whether given ordering is satisfied
-            if !eq_properties.ordering_satisfy_single(&normalized_req) {
-                return false;
+            if !eq_properties.ordering_satisfy_single(normalized_req) {
+                // As soon as one requirement is not satisfied, return
+                // how many we've satisfied so far
+                return i;
             }
             // Treat satisfied keys as constants in subsequent iterations. We
             // can do this because the "next" key only matters in a lexicographical
@@ -575,10 +579,35 @@ impl EquivalenceProperties {
             // From the analysis above, we know that `[a ASC]` is satisfied. Then,
             // we add column `a` as constant to the algorithm state. This enables us
             // to deduce that `(b + c) ASC` is satisfied, given `a` is constant.
-            eq_properties = eq_properties
-                .with_constants(std::iter::once(ConstExpr::from(normalized_req.expr)));
+            eq_properties = eq_properties.with_constants(std::iter::once(
+                ConstExpr::from(Arc::clone(&normalized_req.expr)),
+            ));
         }
-        true
+
+        // All requirements are satisfied.
+        normalized_reqs.len()
+    }
+
+    /// Determines the longest prefix of `reqs` that is satisfied by the existing ordering.
+    /// Returns that prefix as a new `LexRequirement`, and a boolean indicating if all the requirements are satisfied.
+    pub fn extract_common_sort_prefix(
+        &self,
+        reqs: &LexRequirement,
+    ) -> (LexRequirement, bool) {
+        // First, standardize the given requirement:
+        let normalized_reqs = self.normalize_sort_requirements(reqs);
+
+        let prefix_len = self.compute_common_sort_prefix_length(&normalized_reqs);
+        (
+            LexRequirement::new(normalized_reqs[..prefix_len].to_vec()),
+            prefix_len == normalized_reqs.len(),
+        )
+    }
+
+    /// Checks whether the given sort requirements are satisfied by any of the
+    /// existing orderings.
+    pub fn ordering_satisfy_requirement(&self, reqs: &LexRequirement) -> bool {
+        self.extract_common_sort_prefix(reqs).1
     }
 
     /// Checks if the sort requirements are satisfied by any of the table constraints (primary key or unique).
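
As a reading aid, here is a minimal sketch of how the new API composes (a sketch, not part of the patch; `input: Arc<dyn ExecutionPlan>` and `requirement: LexRequirement` are assumed bindings, as in the `SortExec` changes below):

```rust
// Sketch: the old boolean check becomes a prefix computation plus a flag.
let (prefix, fully_satisfied) = input
    .equivalence_properties()
    .extract_common_sort_prefix(&requirement);
if fully_satisfied {
    // Same answer the old `ordering_satisfy_requirement(&requirement)` gave:
    // the input already delivers the requested ordering.
} else if !prefix.is_empty() {
    // Only a leading prefix is satisfied; with a fetch, the TopK operator
    // can still exploit it for the early-exit optimization.
}
```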
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index c0bae77613..5cc7512f88 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -928,6 +928,8 @@ pub struct SortExec {
     preserve_partitioning: bool,
     /// Fetch highest/lowest n results
     fetch: Option<usize>,
+    /// Normalized common sort prefix between the input and the sort expressions (only used with fetch)
+    common_sort_prefix: LexOrdering,
     /// Cache holding plan properties like equivalences, output partitioning etc.
     cache: PlanProperties,
 }
@@ -937,13 +939,15 @@ impl SortExec {
     /// sorted output partition.
     pub fn new(expr: LexOrdering, input: Arc<dyn ExecutionPlan>) -> Self {
         let preserve_partitioning = false;
-        let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning);
+        let (cache, sort_prefix) =
+            Self::compute_properties(&input, expr.clone(), preserve_partitioning);
         Self {
             expr,
             input,
             metrics_set: ExecutionPlanMetricsSet::new(),
             preserve_partitioning,
             fetch: None,
+            common_sort_prefix: sort_prefix,
             cache,
         }
     }
@@ -995,6 +999,7 @@ impl SortExec {
             expr: self.expr.clone(),
             metrics_set: self.metrics_set.clone(),
             preserve_partitioning: self.preserve_partitioning,
+            common_sort_prefix: self.common_sort_prefix.clone(),
             fetch,
             cache,
         }
@@ -1028,19 +1033,21 @@ impl SortExec {
     }
 
     /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
+    /// It also returns the common sort prefix between the input and the sort expressions.
     fn compute_properties(
         input: &Arc<dyn ExecutionPlan>,
         sort_exprs: LexOrdering,
         preserve_partitioning: bool,
-    ) -> PlanProperties {
+    ) -> (PlanProperties, LexOrdering) {
         // Determine execution mode:
         let requirement = LexRequirement::from(sort_exprs);
-        let sort_satisfied = input
+
+        let (sort_prefix, sort_satisfied) = input
             .equivalence_properties()
-            .ordering_satisfy_requirement(&requirement);
+            .extract_common_sort_prefix(&requirement);
 
         // The emission type depends on whether the input is already sorted:
-        // - If already sorted, we can emit results in the same way as the input
+        // - If already fully sorted, we can emit results in the same way as the input
         // - If not sorted, we must wait until all data is processed to emit results (Final)
         let emission_type = if sort_satisfied {
             input.pipeline_behavior()
@@ -1076,11 +1083,14 @@ impl SortExec {
         let output_partitioning =
             Self::output_partitioning_helper(input, preserve_partitioning);
 
-        PlanProperties::new(
-            eq_properties,
-            output_partitioning,
-            emission_type,
-            boundedness,
+        (
+            PlanProperties::new(
+                eq_properties,
+                output_partitioning,
+                emission_type,
+                boundedness,
+            ),
+            LexOrdering::from(sort_prefix),
         )
     }
 }
@@ -1092,7 +1102,12 @@ impl DisplayAs for SortExec {
                 let preserve_partitioning = self.preserve_partitioning;
                 match self.fetch {
                     Some(fetch) => {
-                        write!(f, "SortExec: TopK(fetch={fetch}), expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr)
+                        write!(f, "SortExec: TopK(fetch={fetch}), expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr)?;
+                        if !self.common_sort_prefix.is_empty() {
+                            write!(f, ", sort_prefix=[{}]", self.common_sort_prefix)
+                        } else {
+                            Ok(())
+                        }
                     }
                     None => write!(f, "SortExec: expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr),
                 }
@@ -1168,10 +1183,12 @@ impl ExecutionPlan for SortExec {
 
         trace!("End SortExec's input.execute for partition: {}", partition);
 
+        let requirement = &LexRequirement::from(self.expr.clone());
+
         let sort_satisfied = self
             .input
             .equivalence_properties()
-            .ordering_satisfy_requirement(&LexRequirement::from(self.expr.clone()));
+            .ordering_satisfy_requirement(requirement);
 
         match (sort_satisfied, self.fetch.as_ref()) {
             (true, Some(fetch)) => Ok(Box::pin(LimitStream::new(
@@ -1185,6 +1202,7 @@ impl ExecutionPlan for SortExec {
                 let mut topk = TopK::try_new(
                     partition,
                     input.schema(),
+                    self.common_sort_prefix.clone(),
                     self.expr.clone(),
                     *fetch,
                     context.session_config().batch_size(),
@@ -1197,6 +1215,9 @@ impl ExecutionPlan for SortExec {
                         while let Some(batch) = input.next().await {
                             let batch = batch?;
                             topk.insert_batch(batch)?;
+                            if topk.finished {
+                                break;
+                            }
                         }
                         topk.emit()
                     })
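
Consolidated, the new TopK driver loop in `SortExec::execute` reads roughly as follows (a condensed sketch of the hunk above, not a verbatim excerpt; `input` is the partition stream and `topk` the instance built with the new `common_sort_prefix` argument):

```rust
// Condensed sketch of the early-exit consumption loop added above.
while let Some(batch) = input.next().await {
    topk.insert_batch(batch?)?;
    if topk.finished {
        // The common sort prefix guarantees every remaining row sorts after
        // the current top K, so stop reading input early.
        break;
    }
}
topk.emit()
```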
diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs
index 85de1eefce..405aa52fe0 100644
--- a/datafusion/physical-plan/src/topk/mod.rs
+++ b/datafusion/physical-plan/src/topk/mod.rs
@@ -29,8 +29,8 @@ use crate::spill::get_record_batch_memory_size;
 use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
 use arrow::array::{Array, ArrayRef, RecordBatch};
 use arrow::datatypes::SchemaRef;
-use datafusion_common::HashMap;
 use datafusion_common::Result;
+use datafusion_common::{internal_datafusion_err, HashMap};
 use datafusion_execution::{
     memory_pool::{MemoryConsumer, MemoryReservation},
     runtime_env::RuntimeEnv,
@@ -70,6 +70,25 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering;
 /// The same answer can be produced by simply keeping track of the top
 /// K=3 elements, reducing the total amount of required buffer memory.
 ///
+/// # Partial Sort Optimization
+///
+/// This implementation additionally optimizes queries where the input is already
+/// partially sorted by a common prefix of the requested ordering. Once the top K
+/// heap is full, if subsequent rows are guaranteed to be strictly greater (in sort
+/// order) on this prefix than the largest row currently stored, the operator
+/// safely terminates early.
+///
+/// ## Example
+///
+/// For input sorted by `(day DESC)`, but not by `timestamp`, a query such as:
+///
+/// ```sql
+/// SELECT day, timestamp FROM sensor ORDER BY day DESC, timestamp DESC LIMIT 10;
+/// ```
+///
+/// can terminate scanning early once sufficient rows from the latest days have been
+/// collected, skipping older data.
+///
 /// # Structure
 ///
 /// This operator tracks the top K items using a `TopKHeap`.
@@ -90,15 +109,43 @@ pub struct TopK {
     scratch_rows: Rows,
     /// stores the top k values and their sort key values, in order
     heap: TopKHeap,
+    /// row converter, for common keys between the sort keys and the input ordering
+    common_sort_prefix_converter: Option<RowConverter>,
+    /// Common sort prefix between the input and the sort expressions to allow early exit optimization
+    common_sort_prefix: Arc<[PhysicalSortExpr]>,
+    /// If true, indicates that all rows of subsequent batches are guaranteed
+    /// to be greater (by byte order, after row conversion) than the top K,
+    /// which means the top K won't change and the computation can be finished early.
+    pub(crate) finished: bool,
+}
+
+// Guesstimate for memory allocation: estimated number of bytes used per row in the RowConverter
+const ESTIMATED_BYTES_PER_ROW: usize = 20;
+
+fn build_sort_fields(
+    ordering: &LexOrdering,
+    schema: &SchemaRef,
+) -> Result<Vec<SortField>> {
+    ordering
+        .iter()
+        .map(|e| {
+            Ok(SortField::new_with_options(
+                e.expr.data_type(schema)?,
+                e.options,
+            ))
+        })
+        .collect::<Result<_>>()
 }
 
 impl TopK {
     /// Create a new [`TopK`] that stores the top `k` values, as
     /// defined by the sort expressions in `expr`.
     // TODO: make a builder or some other nicer API
+    #[allow(clippy::too_many_arguments)]
     pub fn try_new(
         partition_id: usize,
         schema: SchemaRef,
+        common_sort_prefix: LexOrdering,
         expr: LexOrdering,
         k: usize,
         batch_size: usize,
@@ -108,35 +155,34 @@ impl TopK {
         let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]"))
             .register(&runtime.memory_pool);
 
-        let expr: Arc<[PhysicalSortExpr]> = expr.into();
-
-        let sort_fields: Vec<_> = expr
-            .iter()
-            .map(|e| {
-                Ok(SortField::new_with_options(
-                    e.expr.data_type(&schema)?,
-                    e.options,
-                ))
-            })
-            .collect::<Result<_>>()?;
+        let sort_fields: Vec<_> = build_sort_fields(&expr, &schema)?;
 
         // TODO there is potential to add special cases for single column sort fields
         // to improve performance
         let row_converter = RowConverter::new(sort_fields)?;
-        let scratch_rows = row_converter.empty_rows(
-            batch_size,
-            20 * batch_size, // guesstimate 20 bytes per row
-        );
+        let scratch_rows =
+            row_converter.empty_rows(batch_size, ESTIMATED_BYTES_PER_ROW * batch_size);
+
+        let prefix_row_converter = if common_sort_prefix.is_empty() {
+            None
+        } else {
+            let input_sort_fields: Vec<_> =
+                build_sort_fields(&common_sort_prefix, &schema)?;
+            Some(RowConverter::new(input_sort_fields)?)
+        };
 
         Ok(Self {
             schema: Arc::clone(&schema),
             metrics: TopKMetrics::new(metrics, partition_id),
             reservation,
             batch_size,
-            expr,
+            expr: Arc::from(expr),
             row_converter,
             scratch_rows,
             heap: TopKHeap::new(k, batch_size, schema),
+            common_sort_prefix_converter: prefix_row_converter,
+            common_sort_prefix: Arc::from(common_sort_prefix),
+            finished: false,
         })
     }
 
@@ -144,7 +190,8 @@ impl TopK {
     /// the top k seen so far.
     pub fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> {
         // Updates on drop
-        let _timer = self.metrics.baseline.elapsed_compute().timer();
+        let baseline = self.metrics.baseline.clone();
+        let _timer = baseline.elapsed_compute().timer();
 
         let sort_keys: Vec<ArrayRef> = self
             .expr
@@ -163,7 +210,7 @@ impl TopK {
         // TODO make this algorithmically better?:
        // Idea: filter out rows >= self.heap.max() early (before passing to `RowConverter`)
         //       this avoids some work and also might be better vectorizable.
-        let mut batch_entry = self.heap.register_batch(batch);
+        let mut batch_entry = self.heap.register_batch(batch.clone());
         for (index, row) in rows.iter().enumerate() {
             match self.heap.max() {
                 // heap has k items, and the new row is greater than the
@@ -183,6 +230,87 @@ impl TopK {
 
         // update memory reservation
         self.reservation.try_resize(self.size())?;
+
+        // flag the topK as finished if we know that all
+        // subsequent batches are guaranteed to be greater (by byte order, after row conversion) than the top K,
+        // which means the top K won't change and the computation can be finished early.
+        self.attempt_early_completion(&batch)?;
+
+        Ok(())
+    }
+
+    /// If input ordering shares a common sort prefix with the TopK, and if the TopK's heap is full,
+    /// check if the computation can be finished early.
+    /// This is the case if the last row of the current batch is strictly greater than the max row in the heap,
+    /// comparing only on the shared prefix columns.
+    fn attempt_early_completion(&mut self, batch: &RecordBatch) -> Result<()> {
+        // Early exit if the batch is empty as there is no last row to extract from it.
+        if batch.num_rows() == 0 {
+            return Ok(());
+        }
+
+        // prefix_row_converter is only `Some` if the input ordering has a common prefix with the TopK,
+        // so early exit if it is `None`.
+        let Some(prefix_converter) = &self.common_sort_prefix_converter else {
+            return Ok(());
+        };
+
+        // Early exit if the heap is not full (`heap.max()` only returns `Some` if the heap is full).
+        let Some(max_topk_row) = self.heap.max() else {
+            return Ok(());
+        };
+
+        // Evaluate the prefix for the last row of the current batch.
+        let last_row_idx = batch.num_rows() - 1;
+        let mut batch_prefix_scratch =
+            prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); // 1 row with capacity ESTIMATED_BYTES_PER_ROW
+
+        self.compute_common_sort_prefix(batch, last_row_idx, &mut batch_prefix_scratch)?;
+
+        // Retrieve the max row from the heap.
+        let store_entry = self
+            .heap
+            .store
+            .get(max_topk_row.batch_id)
+            .ok_or(internal_datafusion_err!("Invalid batch id in topK heap"))?;
+        let max_batch = &store_entry.batch;
+        let mut heap_prefix_scratch =
+            prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); // 1 row with capacity ESTIMATED_BYTES_PER_ROW
+        self.compute_common_sort_prefix(
+            max_batch,
+            max_topk_row.index,
+            &mut heap_prefix_scratch,
+        )?;
+
+        // If the last row's prefix is strictly greater than the max prefix, mark as finished.
+        if batch_prefix_scratch.row(0).as_ref() > heap_prefix_scratch.row(0).as_ref() {
+            self.finished = true;
+        }
+
+        Ok(())
+    }
+
+    // Helper function to compute the prefix for a given batch and row index, storing the result in scratch.
+    fn compute_common_sort_prefix(
+        &self,
+        batch: &RecordBatch,
+        last_row_idx: usize,
+        scratch: &mut Rows,
+    ) -> Result<()> {
+        let last_row: Vec<ArrayRef> = self
+            .common_sort_prefix
+            .iter()
+            .map(|expr| {
+                expr.expr
+                    .evaluate(&batch.slice(last_row_idx, 1))?
+                    .into_array(1)
+            })
+            .collect::<Result<_>>()?;
+
+        self.common_sort_prefix_converter
+            .as_ref()
+            .unwrap()
+            .append(scratch, &last_row)?;
         Ok(())
     }
 
@@ -197,6 +325,9 @@ impl TopK {
             row_converter: _,
             scratch_rows: _,
             mut heap,
+            common_sort_prefix_converter: _,
+            common_sort_prefix: _,
+            finished: _,
         } = self;
         let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop
 
@@ -649,6 +780,10 @@ mod tests {
     use super::*;
     use arrow::array::{Float64Array, Int32Array, RecordBatch};
     use arrow::datatypes::{DataType, Field, Schema};
+    use arrow_schema::SortOptions;
+    use datafusion_common::assert_batches_eq;
+    use datafusion_physical_expr::expressions::col;
+    use futures::TryStreamExt;
 
     /// This test ensures the size calculation is correct for RecordBatches with multiple columns.
     #[test]
@@ -681,4 +816,98 @@ mod tests {
         record_batch_store.unuse(0);
         assert_eq!(record_batch_store.batches_size, 0);
     }
+
+    /// This test validates that `attempt_early_completion` marks the TopK operator as finished
+    /// when the prefix (on column "a") of the last row in the current batch is strictly greater
+    /// than the max top-k row.
+    /// The full sort expression is defined on both columns ("a", "b"), but the input ordering is only on "a".
+    #[tokio::test]
+    async fn test_try_finish_marks_finished_with_prefix() -> Result<()> {
+        // Create a schema with two columns.
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Float64, false),
+        ]));
+
+        // Create sort expressions.
+        // Full sort: first by "a", then by "b".
+        let sort_expr_a = PhysicalSortExpr {
+            expr: col("a", schema.as_ref())?,
+            options: SortOptions::default(),
+        };
+        let sort_expr_b = PhysicalSortExpr {
+            expr: col("b", schema.as_ref())?,
+            options: SortOptions::default(),
+        };
+
+        // Input ordering uses only column "a" (a prefix of the full sort).
+        let input_ordering = LexOrdering::from(vec![sort_expr_a.clone()]);
+        let full_expr = LexOrdering::from(vec![sort_expr_a, sort_expr_b]);
+
+        // Create a dummy runtime environment and metrics.
+        let runtime = Arc::new(RuntimeEnv::default());
+        let metrics = ExecutionPlanMetricsSet::new();
+
+        // Create a TopK instance with k = 3 and batch_size = 2.
+        let mut topk = TopK::try_new(
+            0,
+            Arc::clone(&schema),
+            input_ordering,
+            full_expr,
+            3,
+            2,
+            runtime,
+            &metrics,
+        )?;
+
+        // Create the first batch with two columns:
+        // Column "a": [1, 1, 2], Column "b": [20.0, 15.0, 30.0].
+        let array_a1: ArrayRef =
+            Arc::new(Int32Array::from(vec![Some(1), Some(1), Some(2)]));
+        let array_b1: ArrayRef = Arc::new(Float64Array::from(vec![20.0, 15.0, 30.0]));
+        let batch1 = RecordBatch::try_new(Arc::clone(&schema), vec![array_a1, array_b1])?;
+
+        // Insert the first batch.
+        // At this point the heap is not yet "finished" because the prefix of the last row of the batch
+        // is not strictly greater than the prefix of the max top-k row (both being `2`).
+        topk.insert_batch(batch1)?;
+        assert!(
+            !topk.finished,
+            "Expected 'finished' to be false after the first batch."
+        );
+
+        // Create the second batch with two columns:
+        // Column "a": [2, 3], Column "b": [10.0, 20.0].
+        let array_a2: ArrayRef = Arc::new(Int32Array::from(vec![Some(2), Some(3)]));
+        let array_b2: ArrayRef = Arc::new(Float64Array::from(vec![10.0, 20.0]));
+        let batch2 = RecordBatch::try_new(Arc::clone(&schema), vec![array_a2, array_b2])?;
+
+        // Insert the second batch.
+        // The last row in this batch has a prefix value of `3`,
+        // which is strictly greater than the max top-k row (with value `2`),
+        // so `attempt_early_completion` should mark the TopK as finished.
+        topk.insert_batch(batch2)?;
+        assert!(
+            topk.finished,
+            "Expected 'finished' to be true after the second batch."
+        );
+
+        // Verify the TopK correctly emits the top k rows from both batches
+        // (the value 10.0 for b is from the second batch).
+        let results: Vec<_> = topk.emit()?.try_collect().await?;
+        assert_batches_eq!(
+            &[
+                "+---+------+",
+                "| a | b    |",
+                "+---+------+",
+                "| 1 | 15.0 |",
+                "| 1 | 20.0 |",
+                "| 2 | 10.0 |",
+                "+---+------+",
+            ],
+            &results
+        );
+
+        Ok(())
+    }
 }
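
The safety argument behind `attempt_early_completion` deserves spelling out: the input is sorted on the common prefix, so every row in later batches compares greater than or equal to the current batch's last row on that prefix. Below is a minimal, self-contained illustration, with plain integers standing in for the byte-ordered converted rows (values mirror the unit test above; not part of the patch):

```rust
// Plain-integer stand-in for the prefix comparison in attempt_early_completion.
fn can_finish_early(batch_last_row_prefix: i32, heap_max_prefix: i32) -> bool {
    // The input is sorted on the prefix, so any later row has a prefix
    // >= batch_last_row_prefix. If that is already strictly greater than
    // the heap max's prefix, no future row can displace a heap entry.
    batch_last_row_prefix > heap_max_prefix
}

fn main() {
    // After batch1 both prefixes are 2 (a tie): a later row could still win
    // on the suffix column `b`, so the operator must keep reading.
    assert!(!can_finish_early(2, 2));
    // After batch2 the last row's prefix is 3 > 2: safe to stop scanning.
    assert!(can_finish_early(3, 2));
}
```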
diff --git a/datafusion/sqllogictest/test_files/topk.slt b/datafusion/sqllogictest/test_files/topk.slt
index b5ff95c358..ce23fe2652 100644
--- a/datafusion/sqllogictest/test_files/topk.slt
+++ b/datafusion/sqllogictest/test_files/topk.slt
@@ -233,3 +233,165 @@ d 1 -98 y7C453hRWd4E7ImjNDWlpexB8nUqjh y7C453hRWd4E7ImjNDWlpexB8nUqjh
 e 2 52 xipQ93429ksjNcXPX5326VSg1xJZcW xipQ93429ksjNcXPX5326VSg1xJZcW
 d 1 -72 wwXqSGKLyBQyPkonlzBNYUJTCo4LRS wwXqSGKLyBQyPkonlzBNYUJTCo4LRS
 a 1 -5 waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs
+
+#####################################
+## Test TopK with Partially Sorted Inputs
+#####################################
+
+
+# Create an external table where data is pre-sorted by (number DESC, letter ASC) only.
+statement ok
+CREATE EXTERNAL TABLE partial_sorted (
+    number INT,
+    letter VARCHAR,
+    age INT
+)
+STORED AS parquet
+LOCATION 'test_files/scratch/topk/partial_sorted/1.parquet'
+WITH ORDER (number DESC, letter ASC);
+
+# Insert test data into the external table.
+query I
+COPY (
+  SELECT *
+  FROM (
+    VALUES
+      (1, 'F', 100),
+      (1, 'B', 50),
+      (2, 'C', 70),
+      (2, 'D', 80),
+      (3, 'A', 60),
+      (3, 'E', 90)
+  ) AS t(number, letter, age)
+  ORDER BY number DESC, letter ASC
+)
+TO 'test_files/scratch/topk/partial_sorted/1.parquet';
+----
+6
+
+## explain physical_plan only
+statement ok
+set datafusion.explain.physical_plan_only = true
+
+## batch size smaller than number of rows in the table and result
+statement ok
+set datafusion.execution.batch_size = 2
+
+# Run a TopK query that orders by all columns.
+# Although the table is only guaranteed to be sorted by (number DESC, letter ASC),
+# DataFusion should use the common prefix optimization
+# and return the correct top 3 rows when ordering by all columns.
+query ITI
+select number, letter, age from partial_sorted order by number desc, letter asc, age desc limit 3;
+----
+3 A 60
+3 E 90
+2 C 70
+
+# A more complex example with a projection that includes an expression (see further down for the explained plan)
+query IIITI
+select
+  number + 1 as number_plus,
+  number,
+  number + 1 as other_number_plus,
+  letter,
+  age
+from partial_sorted
+order by
+  number_plus desc,
+  number desc,
+  other_number_plus desc,
+  letter asc,
+  age desc
+limit 3;
+----
+4 3 4 A 60
+4 3 4 E 90
+3 2 3 C 70
+
+# Verify that the physical plan includes the sort prefix.
+# The output should display a "sort_prefix" in the SortExec node.
+query TT
+explain select number, letter, age from partial_sorted order by number desc, letter asc, age desc limit 3;
+----
+physical_plan
+01)SortExec: TopK(fetch=3), expr=[number@0 DESC, letter@1 ASC NULLS LAST, age@2 DESC], preserve_partitioning=[false], sort_prefix=[number@0 DESC, letter@1 ASC NULLS LAST]
+02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet
+
+
+# Explain variations of the above query with different orderings, and different sort prefixes.
+# The "sort_prefix" in the SortExec node should only be present if the TopK's ordering starts with either (number DESC, letter ASC) or just (number DESC).
+query TT
+explain select number, letter, age from partial_sorted order by age desc limit 3;
+----
+physical_plan
+01)SortExec: TopK(fetch=3), expr=[age@2 DESC], preserve_partitioning=[false]
+02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet
+
+query TT
+explain select number, letter, age from partial_sorted order by number desc, letter desc limit 3;
+----
+physical_plan
+01)SortExec: TopK(fetch=3), expr=[number@0 DESC, letter@1 DESC], preserve_partitioning=[false], sort_prefix=[number@0 DESC]
+02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet
+
+query TT
+explain select number, letter, age from partial_sorted order by number asc limit 3;
+----
+physical_plan
+01)SortExec: TopK(fetch=3), expr=[number@0 ASC NULLS LAST], preserve_partitioning=[false]
+02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet
+
+query TT
+explain select number, letter, age from partial_sorted order by letter asc, number desc limit 3;
+----
+physical_plan
+01)SortExec: TopK(fetch=3), expr=[letter@1 ASC NULLS LAST, number@0 DESC], preserve_partitioning=[false]
+02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet
+
+# Explicit NULLS ordering cases (reversing the order of the NULLS on the number and letter orderings)
+query TT
+explain select number, letter, age from partial_sorted order by number desc, letter asc NULLS FIRST limit 3;
+----
+physical_plan
+01)SortExec: TopK(fetch=3), expr=[number@0 DESC, letter@1 ASC], preserve_partitioning=[false], sort_prefix=[number@0 DESC]
+02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet
+
+query TT
+explain select number, letter, age from partial_sorted order by number desc NULLS LAST, letter asc limit 3;
+----
+physical_plan
+01)SortExec: TopK(fetch=3), expr=[number@0 DESC NULLS LAST, letter@1 ASC NULLS LAST], preserve_partitioning=[false]
+02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet
+
+
+# Verify that the sort prefix is correctly computed on the normalized ordering (removing redundant aliased columns)
+query TT
+explain select number, letter, age, number as column4, letter as column5 from partial_sorted order by number desc, column4 desc, letter asc, column5 asc, age desc limit 3;
+----
+physical_plan
+01)SortExec: TopK(fetch=3), expr=[number@0 DESC, column4@3 DESC, letter@1 ASC NULLS LAST, column5@4 ASC NULLS LAST, age@2 DESC], preserve_partitioning=[false], sort_prefix=[number@0 DESC, letter@1 ASC NULLS LAST]
+02)--ProjectionExec: expr=[number@0 as number, letter@1 as letter, age@2 as age, number@0 as column4, letter@1 as column5]
+03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet
+
+# Verify that the sort prefix is correctly computed over normalized, order-maintaining projections (number + 1, number, number + 1, age)
+query TT
+explain select number + 1 as number_plus, number, number + 1 as other_number_plus, age from partial_sorted order by number_plus desc, number desc, other_number_plus desc, age asc limit 3;
+----
+physical_plan
+01)SortPreservingMergeExec: [number_plus@0 DESC, number@1 DESC, other_number_plus@2 DESC, age@3 ASC NULLS LAST], fetch=3
+02)--SortExec: TopK(fetch=3), expr=[number_plus@0 DESC, number@1 DESC, other_number_plus@2 DESC, age@3 ASC NULLS LAST], preserve_partitioning=[true], sort_prefix=[number_plus@0 DESC, number@1 DESC]
+03)----ProjectionExec: expr=[__common_expr_1@0 as number_plus, number@1 as number, __common_expr_1@0 as other_number_plus, age@2 as age]
+04)------ProjectionExec: expr=[CAST(number@0 AS Int64) + 1 as __common_expr_1, number@0 as number, age@1 as age]
+05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, age], output_ordering=[number@0 DESC], file_type=parquet
+
+# Cleanup
+statement ok
+DROP TABLE partial_sorted;
+
+statement ok
+set datafusion.explain.physical_plan_only = false
+
+statement ok
+set datafusion.execution.batch_size = 8192
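
The same check applies to any table declared with `WITH ORDER`: the optimization is active when the `SortExec` line of the EXPLAIN output carries a `sort_prefix`. A hypothetical example (table and columns invented for illustration):

```sql
-- Hypothetical table `events`, declared WITH ORDER (ts DESC):
EXPLAIN SELECT * FROM events ORDER BY ts DESC, id ASC LIMIT 5;
-- Expect a plan line of the form:
--   SortExec: TopK(fetch=5), expr=[ts@0 DESC, id@1 ASC NULLS LAST], preserve_partitioning=[false], sort_prefix=[ts@0 DESC]
```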
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index c5c094cad3..570967621c 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -2356,7 +2356,7 @@ logical_plan
 03)----WindowAggr: windowExpr=[[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
 04)------TableScan: aggregate_test_100 projection=[c9]
 physical_plan
-01)SortExec: TopK(fetch=5), expr=[rn1@1 ASC NULLS LAST, c9@0 ASC NULLS LAST], preserve_partitioning=[false]
+01)SortExec: TopK(fetch=5), expr=[rn1@1 ASC NULLS LAST, c9@0 ASC NULLS LAST], preserve_partitioning=[false], sort_prefix=[rn1@1 ASC NULLS LAST]
 02)--ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1]
 03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode= [...]
 04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false]

