This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 727b6ff415 Add fetch to `SortPreservingMergeExec` and `SortPreservingMergeStream` (#6811)
727b6ff415 is described below

commit 727b6ff41502e276ed7885531a87364a71826a74
Author: Daniël Heres <[email protected]>
AuthorDate: Mon Jul 3 12:29:32 2023 +0200

    Add fetch to `SortPreservingMergeExec` and `SortPreservingMergeStream` (#6811)
    
    * Add fetch to sortpreservingmergeexec
    
    * Add fetch to sortpreservingmergeexec
    
    * fmt
    
    * Deserialize
    
    * Fmt
    
    * Fix test
    
    * Fix test
    
    * Fix test
    
    * Fix plan output
    
    * Doc
    
    * Update datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * Extract into method
    
    * Remove from sort enforcement
    
    * Update datafusion/core/src/physical_plan/sorts/merge.rs
    
    Co-authored-by: Mustafa Akur <[email protected]>
    
    * Update datafusion/proto/src/physical_plan/mod.rs
    
    Co-authored-by: Mustafa Akur <[email protected]>
    
    ---------
    
    Co-authored-by: Daniël Heres <[email protected]>
    Co-authored-by: Andrew Lamb <[email protected]>
    Co-authored-by: Mustafa Akur <[email protected]>
---
 .../physical_optimizer/global_sort_selection.rs    |  2 +-
 .../core/src/physical_plan/repartition/mod.rs      |  1 +
 datafusion/core/src/physical_plan/sorts/merge.rs   | 38 ++++++++++++++++++----
 datafusion/core/src/physical_plan/sorts/sort.rs    |  4 +--
 .../physical_plan/sorts/sort_preserving_merge.rs   | 30 ++++++++++++++---
 datafusion/core/tests/sql/explain_analyze.rs       |  2 +-
 .../sqllogictests/test_files/tpch/q10.slt.part     |  2 +-
 .../sqllogictests/test_files/tpch/q11.slt.part     |  2 +-
 .../sqllogictests/test_files/tpch/q13.slt.part     |  2 +-
 .../sqllogictests/test_files/tpch/q16.slt.part     |  2 +-
 .../sqllogictests/test_files/tpch/q2.slt.part      |  2 +-
 .../sqllogictests/test_files/tpch/q3.slt.part      |  2 +-
 .../sqllogictests/test_files/tpch/q9.slt.part      |  2 +-
 .../core/tests/sqllogictests/test_files/union.slt  |  2 +-
 .../core/tests/sqllogictests/test_files/window.slt |  2 +-
 datafusion/proto/proto/datafusion.proto            |  2 ++
 datafusion/proto/src/generated/pbjson.rs           | 19 +++++++++++
 datafusion/proto/src/generated/prost.rs            |  3 ++
 datafusion/proto/src/physical_plan/mod.rs          | 10 +++++-
 19 files changed, 103 insertions(+), 26 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/global_sort_selection.rs 
b/datafusion/core/src/physical_optimizer/global_sort_selection.rs
index 9466297d24..0b9054f89f 100644
--- a/datafusion/core/src/physical_optimizer/global_sort_selection.rs
+++ b/datafusion/core/src/physical_optimizer/global_sort_selection.rs
@@ -70,7 +70,7 @@ impl PhysicalOptimizerRule for GlobalSortSelection {
                                 Arc::new(SortPreservingMergeExec::new(
                                     sort_exec.expr().to_vec(),
                                     Arc::new(sort),
-                                ));
+                                ).with_fetch(sort_exec.fetch()));
                             Some(global_sort)
                         } else {
                             None
diff --git a/datafusion/core/src/physical_plan/repartition/mod.rs 
b/datafusion/core/src/physical_plan/repartition/mod.rs
index 72ff0c3713..85225eb471 100644
--- a/datafusion/core/src/physical_plan/repartition/mod.rs
+++ b/datafusion/core/src/physical_plan/repartition/mod.rs
@@ -497,6 +497,7 @@ impl ExecutionPlan for RepartitionExec {
                 sort_exprs,
                 BaselineMetrics::new(&self.metrics, partition),
                 context.session_config().batch_size(),
+                None,
             )
         } else {
             Ok(Box::pin(RepartitionStream {
diff --git a/datafusion/core/src/physical_plan/sorts/merge.rs 
b/datafusion/core/src/physical_plan/sorts/merge.rs
index d8a3cdef4d..e191c044b9 100644
--- a/datafusion/core/src/physical_plan/sorts/merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/merge.rs
@@ -39,13 +39,14 @@ macro_rules! primitive_merge_helper {
 }
 
 macro_rules! merge_helper {
-    ($t:ty, $sort:ident, $streams:ident, $schema:ident, 
$tracking_metrics:ident, $batch_size:ident) => {{
+    ($t:ty, $sort:ident, $streams:ident, $schema:ident, 
$tracking_metrics:ident, $batch_size:ident, $fetch:ident) => {{
         let streams = FieldCursorStream::<$t>::new($sort, $streams);
         return Ok(Box::pin(SortPreservingMergeStream::new(
             Box::new(streams),
             $schema,
             $tracking_metrics,
             $batch_size,
+            $fetch,
         )));
     }};
 }
@@ -57,17 +58,18 @@ pub(crate) fn streaming_merge(
     expressions: &[PhysicalSortExpr],
     metrics: BaselineMetrics,
     batch_size: usize,
+    fetch: Option<usize>,
 ) -> Result<SendableRecordBatchStream> {
     // Special case single column comparisons with optimized cursor 
implementations
     if expressions.len() == 1 {
         let sort = expressions[0].clone();
         let data_type = sort.expr.data_type(schema.as_ref())?;
         downcast_primitive! {
-            data_type => (primitive_merge_helper, sort, streams, schema, 
metrics, batch_size),
-            DataType::Utf8 => merge_helper!(StringArray, sort, streams, 
schema, metrics, batch_size)
-            DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, 
streams, schema, metrics, batch_size)
-            DataType::Binary => merge_helper!(BinaryArray, sort, streams, 
schema, metrics, batch_size)
-            DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, 
streams, schema, metrics, batch_size)
+            data_type => (primitive_merge_helper, sort, streams, schema, 
metrics, batch_size, fetch),
+            DataType::Utf8 => merge_helper!(StringArray, sort, streams, 
schema, metrics, batch_size, fetch)
+            DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, 
streams, schema, metrics, batch_size, fetch)
+            DataType::Binary => merge_helper!(BinaryArray, sort, streams, 
schema, metrics, batch_size, fetch)
+            DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, 
streams, schema, metrics, batch_size, fetch)
             _ => {}
         }
     }
@@ -78,6 +80,7 @@ pub(crate) fn streaming_merge(
         schema,
         metrics,
         batch_size,
+        fetch,
     )))
 }
 
@@ -140,6 +143,12 @@ struct SortPreservingMergeStream<C> {
 
     /// Vector that holds cursors for each non-exhausted input partition
     cursors: Vec<Option<C>>,
+
+    /// Optional number of rows to fetch
+    fetch: Option<usize>,
+
+    /// number of rows produced
+    produced: usize,
 }
 
 impl<C: Cursor> SortPreservingMergeStream<C> {
@@ -148,6 +157,7 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
         schema: SchemaRef,
         metrics: BaselineMetrics,
         batch_size: usize,
+        fetch: Option<usize>,
     ) -> Self {
         let stream_count = streams.partitions();
 
@@ -160,6 +170,8 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
             loser_tree: vec![],
             loser_tree_adjusted: false,
             batch_size,
+            fetch,
+            produced: 0,
         }
     }
 
@@ -227,15 +239,27 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
             if self.advance(stream_idx) {
                 self.loser_tree_adjusted = false;
                 self.in_progress.push_row(stream_idx);
-                if self.in_progress.len() < self.batch_size {
+
+                // stop sorting if fetch has been reached
+                if self.fetch_reached() {
+                    self.aborted = true;
+                } else if self.in_progress.len() < self.batch_size {
                     continue;
                 }
             }
 
+            self.produced += self.in_progress.len();
+
             return 
Poll::Ready(self.in_progress.build_record_batch().transpose());
         }
     }
 
+    fn fetch_reached(&mut self) -> bool {
+        self.fetch
+            .map(|fetch| self.produced + self.in_progress.len() >= fetch)
+            .unwrap_or(false)
+    }
+
     fn advance(&mut self, stream_idx: usize) -> bool {
         let slot = &mut self.cursors[stream_idx];
         match slot.as_mut() {
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs 
b/datafusion/core/src/physical_plan/sorts/sort.rs
index 4983b0ea83..205ec706b5 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -189,6 +189,7 @@ impl ExternalSorter {
                 &self.expr,
                 self.metrics.baseline.clone(),
                 self.batch_size,
+                self.fetch,
             )
         } else if !self.in_mem_batches.is_empty() {
             let result = 
self.in_mem_sort_stream(self.metrics.baseline.clone());
@@ -285,14 +286,13 @@ impl ExternalSorter {
             })
             .collect::<Result<_>>()?;
 
-        // TODO: Pushdown fetch to streaming merge (#6000)
-
         streaming_merge(
             streams,
             self.schema.clone(),
             &self.expr,
             metrics,
             self.batch_size,
+            self.fetch,
         )
     }
 
diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs 
b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
index 4db1fea2a4..397d254162 100644
--- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -71,6 +71,8 @@ pub struct SortPreservingMergeExec {
     expr: Vec<PhysicalSortExpr>,
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
+    /// Optional number of rows to fetch. Stops producing rows after this fetch
+    fetch: Option<usize>,
 }
 
 impl SortPreservingMergeExec {
@@ -80,8 +82,14 @@ impl SortPreservingMergeExec {
             input,
             expr,
             metrics: ExecutionPlanMetricsSet::new(),
+            fetch: None,
         }
     }
+    /// Sets the number of rows to fetch
+    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
+        self.fetch = fetch;
+        self
+    }
 
     /// Input schema
     pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
@@ -92,6 +100,11 @@ impl SortPreservingMergeExec {
     pub fn expr(&self) -> &[PhysicalSortExpr] {
         &self.expr
     }
+
+    /// Fetch
+    pub fn fetch(&self) -> Option<usize> {
+        self.fetch
+    }
 }
 
 impl ExecutionPlan for SortPreservingMergeExec {
@@ -137,10 +150,10 @@ impl ExecutionPlan for SortPreservingMergeExec {
         self: Arc<Self>,
         children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        Ok(Arc::new(SortPreservingMergeExec::new(
-            self.expr.clone(),
-            children[0].clone(),
-        )))
+        Ok(Arc::new(
+            SortPreservingMergeExec::new(self.expr.clone(), 
children[0].clone())
+                .with_fetch(self.fetch),
+        ))
     }
 
     fn execute(
@@ -192,6 +205,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
                     &self.expr,
                     BaselineMetrics::new(&self.metrics, partition),
                     context.session_config().batch_size(),
+                    self.fetch,
                 )?;
 
                 debug!("Got stream result from 
SortPreservingMergeStream::new_from_receivers");
@@ -209,7 +223,12 @@ impl ExecutionPlan for SortPreservingMergeExec {
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
                 let expr: Vec<String> = self.expr.iter().map(|e| 
e.to_string()).collect();
-                write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))
+                write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))?;
+                if let Some(fetch) = self.fetch {
+                    write!(f, ", fetch={fetch}")?;
+                };
+
+                Ok(())
             }
         }
     }
@@ -814,6 +833,7 @@ mod tests {
             sort.as_slice(),
             BaselineMetrics::new(&metrics, 0),
             task_ctx.session_config().batch_size(),
+            None,
         )
         .unwrap();
 
diff --git a/datafusion/core/tests/sql/explain_analyze.rs 
b/datafusion/core/tests/sql/explain_analyze.rs
index 01bdb629ee..e0130cb09c 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -599,7 +599,7 @@ async fn test_physical_plan_display_indent() {
     let physical_plan = dataframe.create_physical_plan().await.unwrap();
     let expected = vec![
         "GlobalLimitExec: skip=0, fetch=10",
-        "  SortPreservingMergeExec: [the_min@2 DESC]",
+        "  SortPreservingMergeExec: [the_min@2 DESC], fetch=10",
         "    SortExec: fetch=10, expr=[the_min@2 DESC]",
         "      ProjectionExec: expr=[c1@0 as c1, MAX(aggregate_test_100.c12)@1 
as MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)@2 as the_min]",
         "        AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], 
aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]",
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part 
b/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part
index d2e06d5ff6..6c662c1091 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part
@@ -71,7 +71,7 @@ Limit: skip=0, fetch=10
 ------------TableScan: nation projection=[n_nationkey, n_name]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [revenue@2 DESC]
+--SortPreservingMergeExec: [revenue@2 DESC], fetch=10
 ----SortExec: fetch=10, expr=[revenue@2 DESC]
 ------ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name, 
SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@7 as revenue, 
c_acctbal@2 as c_acctbal, n_name@4 as n_name, c_address@5 as c_address, 
c_phone@3 as c_phone, c_comment@6 as c_comment]
 --------AggregateExec: mode=FinalPartitioned, gby=[c_custkey@0 as c_custkey, 
c_name@1 as c_name, c_acctbal@2 as c_acctbal, c_phone@3 as c_phone, n_name@4 as 
n_name, c_address@5 as c_address, c_comment@6 as c_comment], 
aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part 
b/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part
index af29708c67..0c16fe1ab9 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part
@@ -75,7 +75,7 @@ Limit: skip=0, fetch=10
 ----------------------TableScan: nation projection=[n_nationkey, n_name], 
partial_filters=[nation.n_name = Utf8("GERMANY")]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [value@1 DESC]
+--SortPreservingMergeExec: [value@1 DESC], fetch=10
 ----SortExec: fetch=10, expr=[value@1 DESC]
 ------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, 
SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value]
 --------NestedLoopJoinExec: join_type=Inner, 
filter=CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS 
Decimal128(38, 15)) > SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * 
Float64(0.0001)@1
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part 
b/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part
index 7e5be14271..8ac9576a12 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part
@@ -56,7 +56,7 @@ Limit: skip=0, fetch=10
 ------------------------TableScan: orders projection=[o_orderkey, o_custkey, 
o_comment], partial_filters=[orders.o_comment NOT LIKE 
Utf8("%special%requests%")]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC]
+--SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC], fetch=10
 ----SortExec: fetch=10, expr=[custdist@1 DESC,c_count@0 DESC]
 ------ProjectionExec: expr=[c_count@0 as c_count, COUNT(UInt8(1))@1 as 
custdist]
 --------AggregateExec: mode=FinalPartitioned, gby=[c_count@0 as c_count], 
aggr=[COUNT(UInt8(1))]
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part 
b/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part
index 677db0329c..58796e93a8 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part
@@ -67,7 +67,7 @@ Limit: skip=0, fetch=10
 ------------------TableScan: supplier projection=[s_suppkey, s_comment], 
partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS 
LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST]
+--SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS 
LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], fetch=10
 ----SortExec: fetch=10, expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS 
LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST]
 ------ProjectionExec: expr=[group_alias_0@0 as part.p_brand, group_alias_1@1 
as part.p_type, group_alias_2@2 as part.p_size, COUNT(alias1)@3 as supplier_cnt]
 --------AggregateExec: mode=FinalPartitioned, gby=[group_alias_0@0 as 
group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as 
group_alias_2], aggr=[COUNT(alias1)]
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part 
b/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part
index 4ad1ed7293..18cd261b76 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part
@@ -101,7 +101,7 @@ Limit: skip=0, fetch=10
 ----------------------TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[region.r_name = Utf8("EUROPE")]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 
ASC NULLS LAST,p_partkey@3 ASC NULLS LAST]
+--SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 
ASC NULLS LAST,p_partkey@3 ASC NULLS LAST], fetch=10
 ----SortExec: fetch=10, expr=[s_acctbal@0 DESC,n_name@2 ASC NULLS 
LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST]
 ------ProjectionExec: expr=[s_acctbal@5 as s_acctbal, s_name@2 as s_name, 
n_name@8 as n_name, p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_address@3 
as s_address, s_phone@4 as s_phone, s_comment@6 as s_comment]
 --------CoalesceBatchesExec: target_batch_size=8192
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part 
b/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part
index dc3b150877..f8c1385681 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part
@@ -60,7 +60,7 @@ Limit: skip=0, fetch=10
 ----------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, 
l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate > Date32("9204")]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST]
+--SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST], 
fetch=10
 ----SortExec: fetch=10, expr=[revenue@1 DESC,o_orderdate@2 ASC NULLS LAST]
 ------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, 
SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@3 as revenue, 
o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority]
 --------AggregateExec: mode=FinalPartitioned, gby=[l_orderkey@0 as l_orderkey, 
o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority], 
aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part 
b/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part
index 756b2e2c7c..45a4be6466 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part
@@ -77,7 +77,7 @@ Limit: skip=0, fetch=10
 --------------TableScan: nation projection=[n_nationkey, n_name]
 physical_plan
 GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC]
+--SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC], fetch=10
 ----SortExec: fetch=10, expr=[nation@0 ASC NULLS LAST,o_year@1 DESC]
 ------ProjectionExec: expr=[nation@0 as nation, o_year@1 as o_year, 
SUM(profit.amount)@2 as sum_profit]
 --------AggregateExec: mode=FinalPartitioned, gby=[nation@0 as nation, 
o_year@1 as o_year], aggr=[SUM(profit.amount)]
diff --git a/datafusion/core/tests/sqllogictests/test_files/union.slt 
b/datafusion/core/tests/sqllogictests/test_files/union.slt
index 94c9eef893..2b3022ddd1 100644
--- a/datafusion/core/tests/sqllogictests/test_files/union.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/union.slt
@@ -308,7 +308,7 @@ Limit: skip=0, fetch=5
 --------TableScan: aggregate_test_100 projection=[c1, c3]
 physical_plan
 GlobalLimitExec: skip=0, fetch=5
---SortPreservingMergeExec: [c9@1 DESC]
+--SortPreservingMergeExec: [c9@1 DESC], fetch=5
 ----UnionExec
 ------SortExec: expr=[c9@1 DESC]
 --------ProjectionExec: expr=[c1@0 as c1, CAST(c9@1 AS Int64) as c9]
diff --git a/datafusion/core/tests/sqllogictests/test_files/window.slt 
b/datafusion/core/tests/sqllogictests/test_files/window.slt
index 08d1a5616e..d77df127a8 100644
--- a/datafusion/core/tests/sqllogictests/test_files/window.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/window.slt
@@ -1792,7 +1792,7 @@ Limit: skip=0, fetch=5
 ------------TableScan: aggregate_test_100 projection=[c2, c3, c9]
 physical_plan
 GlobalLimitExec: skip=0, fetch=5
---SortPreservingMergeExec: [c3@0 ASC NULLS LAST]
+--SortPreservingMergeExec: [c3@0 ASC NULLS LAST], fetch=5
 ----ProjectionExec: expr=[c3@0 as c3, SUM(aggregate_test_100.c9) ORDER BY 
[aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS 
FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING 
AND CURRENT ROW@2 as sum1, SUM(aggregate_test_100.c9) PARTITION BY 
[aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE 
BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum2]
 ------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: 
"SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, 
start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted]
 --------SortExec: expr=[c3@0 ASC NULLS LAST,c9@1 DESC]
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index de334dc4a5..0d61cd2b35 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1366,6 +1366,8 @@ message SortExecNode {
 message SortPreservingMergeExecNode {
   PhysicalPlanNode input = 1;
   repeated PhysicalExprNode expr = 2;
+  // Maximum number of highest/lowest rows to fetch; negative means no limit
+  int64 fetch = 3;
 }
 
 message CoalesceBatchesExecNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs 
b/datafusion/proto/src/generated/pbjson.rs
index 1cf08be321..831dd49618 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -20269,6 +20269,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
         if !self.expr.is_empty() {
             len += 1;
         }
+        if self.fetch != 0 {
+            len += 1;
+        }
         let mut struct_ser = 
serializer.serialize_struct("datafusion.SortPreservingMergeExecNode", len)?;
         if let Some(v) = self.input.as_ref() {
             struct_ser.serialize_field("input", v)?;
@@ -20276,6 +20279,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
         if !self.expr.is_empty() {
             struct_ser.serialize_field("expr", &self.expr)?;
         }
+        if self.fetch != 0 {
+            struct_ser.serialize_field("fetch", 
ToString::to_string(&self.fetch).as_str())?;
+        }
         struct_ser.end()
     }
 }
@@ -20288,12 +20294,14 @@ impl<'de> serde::Deserialize<'de> for 
SortPreservingMergeExecNode {
         const FIELDS: &[&str] = &[
             "input",
             "expr",
+            "fetch",
         ];
 
         #[allow(clippy::enum_variant_names)]
         enum GeneratedField {
             Input,
             Expr,
+            Fetch,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
@@ -20317,6 +20325,7 @@ impl<'de> serde::Deserialize<'de> for 
SortPreservingMergeExecNode {
                         match value {
                             "input" => Ok(GeneratedField::Input),
                             "expr" => Ok(GeneratedField::Expr),
+                            "fetch" => Ok(GeneratedField::Fetch),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
                     }
@@ -20338,6 +20347,7 @@ impl<'de> serde::Deserialize<'de> for 
SortPreservingMergeExecNode {
             {
                 let mut input__ = None;
                 let mut expr__ = None;
+                let mut fetch__ = None;
                 while let Some(k) = map.next_key()? {
                     match k {
                         GeneratedField::Input => {
@@ -20352,11 +20362,20 @@ impl<'de> serde::Deserialize<'de> for 
SortPreservingMergeExecNode {
                             }
                             expr__ = Some(map.next_value()?);
                         }
+                        GeneratedField::Fetch => {
+                            if fetch__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("fetch"));
+                            }
+                            fetch__ = 
+                                
Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+                            ;
+                        }
                     }
                 }
                 Ok(SortPreservingMergeExecNode {
                     input: input__,
                     expr: expr__.unwrap_or_default(),
+                    fetch: fetch__.unwrap_or_default(),
                 })
             }
         }
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index 5f201b124d..e6c076e7d4 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1926,6 +1926,9 @@ pub struct SortPreservingMergeExecNode {
     pub input: 
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
     #[prost(message, repeated, tag = "2")]
     pub expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
+    /// Maximum number of highest/lowest rows to fetch; negative means no limit
+    #[prost(int64, tag = "3")]
+    pub fetch: i64,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs 
b/datafusion/proto/src/physical_plan/mod.rs
index 1daa1c2e4b..7bbbe13568 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -692,7 +692,14 @@ impl AsExecutionPlan for PhysicalPlanNode {
                         }
                     })
                     .collect::<Result<Vec<_>, _>>()?;
-                Ok(Arc::new(SortPreservingMergeExec::new(exprs, input)))
+                let fetch = if sort.fetch < 0 {
+                    None
+                } else {
+                    Some(sort.fetch as usize)
+                };
+                Ok(Arc::new(
+                    SortPreservingMergeExec::new(exprs, 
input).with_fetch(fetch),
+                ))
             }
             PhysicalPlanType::Extension(extension) => {
                 let inputs: Vec<Arc<dyn ExecutionPlan>> = extension
@@ -1144,6 +1151,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     Box::new(protobuf::SortPreservingMergeExecNode {
                         input: Some(Box::new(input)),
                         expr,
+                        fetch: exec.fetch().map(|f| f as i64).unwrap_or(-1),
                     }),
                 )),
             })

Reply via email to