This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 80abc9491 Pipeline-friendly Bounded Memory Window Executor (#4777)
80abc9491 is described below
commit 80abc9491015cd2ae5879ad47bf6f386017cf64e
Author: Mustafa akur <[email protected]>
AuthorDate: Wed Jan 4 23:30:01 2023 +0300
Pipeline-friendly Bounded Memory Window Executor (#4777)
* Sort Removal rule initial commit
* move ordering satisfy to the util
* update test and change repartition maintain_input_order impl
* simplifications
* partition by refactor (#28)
* partition by refactor
* minor changes
* Unnecessary tuple to Range conversion is removed
* move transpose under common
* Add naive sort removal rule
* Add todo for finer Sort removal handling
* Refactors to improve readability and reduce nesting
* reverse expr returns Option (no need for support check)
* fix tests
* partition by and order by no longer end up in the same window group
* Bounded window exec
* solve merge problems
* Refactor to simplify code
* Better comments, change method names
* resolve merge conflicts
* Resolve errors introduced by syncing
* remove set_state, make ntile debuggable
* remove locked flag
* address reviews
* address reviews
* Resolve merge conflict
* address reviews
* address reviews
* address reviews
* Add new tests
* Update tests
* add support for bounded min max
* address reviews
* rename sort rule
* Resolve merge conflicts
* refactors
* Update fuzzy tests + minor changes
* Simplify code and improve comments
* Fix imports, make create_schema more functional
* address reviews
* undo yml change
* minor change to pass from CI
* resolve merge conflicts
* rename some members
* Move rule to physical planning
* Minor stylistic/comment changes
* Simplify batch-merging utility functions
* Remove unnecessary clones, simplify code
* update cargo lock file
* address reviews
* update comments
* resolve linter error
* Tidy up comments after final review
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
.github/workflows/rust.yml | 1 +
datafusion-cli/Cargo.lock | 2 +
datafusion/core/Cargo.toml | 1 +
datafusion/core/src/execution/context.rs | 43 +-
.../core/src/physical_optimizer/optimize_sorts.rs | 64 +-
.../src/physical_optimizer/pipeline_checker.rs | 4 +-
datafusion/core/src/physical_plan/common.rs | 42 ++
datafusion/core/src/physical_plan/planner.rs | 31 +-
.../windows/bounded_window_agg_exec.rs | 705 +++++++++++++++++++++
datafusion/core/src/physical_plan/windows/mod.rs | 2 +
datafusion/core/tests/sql/window.rs | 280 +++++++-
datafusion/core/tests/window_fuzz.rs | 385 +++++++++++
datafusion/physical-expr/Cargo.toml | 1 +
datafusion/physical-expr/src/aggregate/count.rs | 4 +
datafusion/physical-expr/src/aggregate/min_max.rs | 20 +-
datafusion/physical-expr/src/aggregate/mod.rs | 6 +
datafusion/physical-expr/src/aggregate/sum.rs | 4 +
datafusion/physical-expr/src/window/aggregate.rs | 10 +-
datafusion/physical-expr/src/window/built_in.rs | 120 +++-
.../src/window/built_in_window_function_expr.rs | 8 +
datafusion/physical-expr/src/window/cume_dist.rs | 1 +
datafusion/physical-expr/src/window/lead_lag.rs | 78 ++-
datafusion/physical-expr/src/window/mod.rs | 6 +
datafusion/physical-expr/src/window/nth_value.rs | 38 +-
datafusion/physical-expr/src/window/ntile.rs | 1 +
.../src/window/partition_evaluator.rs | 36 +-
datafusion/physical-expr/src/window/rank.rs | 64 +-
datafusion/physical-expr/src/window/row_number.rs | 34 +-
.../physical-expr/src/window/sliding_aggregate.rs | 204 ++++--
datafusion/physical-expr/src/window/window_expr.rs | 138 +++-
test-utils/src/lib.rs | 5 +-
31 files changed, 2205 insertions(+), 133 deletions(-)
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index c15615d59..c4d8cd533 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -64,6 +64,7 @@ jobs:
- name: Check Cargo.lock for datafusion-cli
run: |
# If this test fails, try running `cargo update` in the
`datafusion-cli` directory
+ # and check in the updated Cargo.lock file.
cargo check --manifest-path datafusion-cli/Cargo.toml --locked
# test the crate
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 5c0c00b7a..a34d9f4d4 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -673,6 +673,7 @@ dependencies = [
"futures",
"glob",
"hashbrown 0.13.1",
+ "indexmap",
"itertools",
"lazy_static",
"log",
@@ -765,6 +766,7 @@ dependencies = [
"datafusion-row",
"half",
"hashbrown 0.13.1",
+ "indexmap",
"itertools",
"lazy_static",
"md-5",
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml
index 31b85cb18..44c25f593 100644
--- a/datafusion/core/Cargo.toml
+++ b/datafusion/core/Cargo.toml
@@ -75,6 +75,7 @@ flate2 = { version = "1.0.24", optional = true }
futures = "0.3"
glob = "0.3.0"
hashbrown = { version = "0.13", features = ["raw"] }
+indexmap = "1.9.2"
itertools = "0.10"
lazy_static = { version = "^1.4.0" }
log = "^0.4"
diff --git a/datafusion/core/src/execution/context.rs
b/datafusion/core/src/execution/context.rs
index 45131493e..32cf8d165 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -1453,26 +1453,25 @@ impl SessionState {
// We need to take care of the rule ordering. They may influence each
other.
let physical_optimizers: Vec<Arc<dyn PhysicalOptimizerRule + Sync +
Send>> = vec![
Arc::new(AggregateStatistics::new()),
- // - In order to increase the parallelism, it will change the
output partitioning
- // of some operators in the plan tree, which will influence other
rules.
- // Therefore, it should be run as soon as possible.
- // - The reason to make it optional is
- // - it's not used for the distributed engine, Ballista.
- // - it's conflicted with some parts of the BasicEnforcement,
since it will
- // introduce additional repartitioning while the
BasicEnforcement aims at
- // reducing unnecessary repartitioning.
+ // In order to increase the parallelism, the Repartition rule will
change the
+ // output partitioning of some operators in the plan tree, which
will influence
+ // other rules. Therefore, it should run as soon as possible. It
is optional because:
+ // - It's not used for the distributed engine, Ballista.
+ // - It's conflicted with some parts of the BasicEnforcement,
since it will
+ // introduce additional repartitioning while the
BasicEnforcement aims at
+ // reducing unnecessary repartitioning.
Arc::new(Repartition::new()),
- //- Currently it will depend on the partition number to decide
whether to change the
- // single node sort to parallel local sort and merge. Therefore,
it should be run
- // after the Repartition.
- // - Since it will change the output ordering of some operators,
it should be run
+ // - Currently it will depend on the partition number to decide
whether to change the
+ // single node sort to parallel local sort and merge. Therefore,
GlobalSortSelection
+ // should run after the Repartition.
+ // - Since it will change the output ordering of some operators,
it should run
// before JoinSelection and BasicEnforcement, which may depend on
that.
Arc::new(GlobalSortSelection::new()),
- // Statistics-base join selection will change the Auto mode to
real join implementation,
+ // Statistics-based join selection will change the Auto mode to a
real join implementation,
// like collect left, or hash join, or future sort merge join,
which will
// influence the BasicEnforcement to decide whether to add
additional repartition
// and local sort to meet the distribution and ordering
requirements.
- // Therefore, it should be run before BasicEnforcement
+ // Therefore, it should run before BasicEnforcement.
Arc::new(JoinSelection::new()),
// If the query is processing infinite inputs, the PipelineFixer
rule applies the
// necessary transformations to make the query runnable (if it is
not already runnable).
@@ -1480,17 +1479,17 @@ impl SessionState {
// Since the transformations it applies may alter output
partitioning properties of operators
// (e.g. by swapping hash join sides), this rule runs before
BasicEnforcement.
Arc::new(PipelineFixer::new()),
- // It's for adding essential repartition and local sorting
operator to satisfy the
- // required distribution and local sort.
+ // BasicEnforcement is for adding essential repartition and local
sorting operators
+ // to satisfy the required distribution and local sort
requirements.
// Please make sure that the whole plan tree is determined.
Arc::new(BasicEnforcement::new()),
- // `BasicEnforcement` stage conservatively inserts `SortExec`s to
satisfy ordering requirements.
- // However, a deeper analysis may sometimes reveal that such a
`SortExec` is actually unnecessary.
- // These cases typically arise when we have reversible
`WindowAggExec`s or deep subqueries. The
- // rule below performs this analysis and removes unnecessary
`SortExec`s.
+ // The BasicEnforcement stage conservatively inserts sorts to
satisfy ordering requirements.
+ // However, a deeper analysis may sometimes reveal that such a
sort is actually unnecessary.
+ // These cases typically arise when we have reversible window
expressions or deep subqueries.
+ // The rule below performs this analysis and removes unnecessary
sorts.
Arc::new(OptimizeSorts::new()),
- // It will not influence the distribution and ordering of the
whole plan tree.
- // Therefore, to avoid influencing other rules, it should be run
at last.
+ // The CoalesceBatches rule will not influence the distribution
and ordering of the
+ // whole plan tree. Therefore, to avoid influencing other rules,
it should run last.
Arc::new(CoalesceBatches::new()),
// The PipelineChecker rule will reject non-runnable query plans
that use
// pipeline-breaking operators on infinite input(s). The rule
generates a
diff --git a/datafusion/core/src/physical_optimizer/optimize_sorts.rs
b/datafusion/core/src/physical_optimizer/optimize_sorts.rs
index a47026cc7..17b27bfa7 100644
--- a/datafusion/core/src/physical_optimizer/optimize_sorts.rs
+++ b/datafusion/core/src/physical_optimizer/optimize_sorts.rs
@@ -33,10 +33,11 @@ use crate::physical_optimizer::utils::{
use crate::physical_optimizer::PhysicalOptimizerRule;
use crate::physical_plan::rewrite::TreeNodeRewritable;
use crate::physical_plan::sorts::sort::SortExec;
-use crate::physical_plan::windows::WindowAggExec;
+use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec};
use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan};
use arrow::datatypes::SchemaRef;
use datafusion_common::{reverse_sort_options, DataFusionError};
+use datafusion_physical_expr::window::WindowExpr;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use itertools::izip;
use std::iter::zip;
@@ -181,17 +182,32 @@ fn optimize_sorts(
sort_exec.input().equivalence_properties()
}) {
update_child_to_remove_unnecessary_sort(child,
sort_onwards)?;
- } else if let Some(window_agg_exec) =
+ }
+ // For window expressions, we can remove some sorts when
we can
+ // calculate the result in reverse:
+ else if let Some(exec) =
requirements.plan.as_any().downcast_ref::<WindowAggExec>()
{
- // For window expressions, we can remove some sorts
when we can
- // calculate the result in reverse:
- if let Some(res) = analyze_window_sort_removal(
- window_agg_exec,
+ if let Some(result) = analyze_window_sort_removal(
+ exec.window_expr(),
+ &exec.partition_keys,
+ sort_exec,
+ sort_onwards,
+ )? {
+ return Ok(Some(result));
+ }
+ } else if let Some(exec) = requirements
+ .plan
+ .as_any()
+ .downcast_ref::<BoundedWindowAggExec>()
+ {
+ if let Some(result) = analyze_window_sort_removal(
+ exec.window_expr(),
+ &exec.partition_keys,
sort_exec,
sort_onwards,
)? {
- return Ok(Some(res));
+ return Ok(Some(result));
}
}
// TODO: Once we can ensure that required ordering
information propagates with
@@ -273,9 +289,11 @@ fn analyze_immediate_sort_removal(
Ok(None)
}
-/// Analyzes a `WindowAggExec` to determine whether it may allow removing a
sort.
+/// Analyzes a [WindowAggExec] or a [BoundedWindowAggExec] to determine whether
+/// it may allow removing a sort.
fn analyze_window_sort_removal(
- window_agg_exec: &WindowAggExec,
+ window_expr: &[Arc<dyn WindowExpr>],
+ partition_keys: &[Arc<dyn PhysicalExpr>],
sort_exec: &SortExec,
sort_onward: &mut Vec<(usize, Arc<dyn ExecutionPlan>)>,
) -> Result<Option<PlanWithCorrespondingSort>> {
@@ -289,7 +307,6 @@ fn analyze_window_sort_removal(
// If there is no physical ordering, there is no way to remove a sort
-- immediately return:
return Ok(None);
};
- let window_expr = window_agg_exec.window_expr();
let (can_skip_sorting, should_reverse) = can_skip_sort(
window_expr[0].partition_by(),
required_ordering,
@@ -308,13 +325,26 @@ fn analyze_window_sort_removal(
if let Some(window_expr) = new_window_expr {
let new_child =
remove_corresponding_sort_from_sub_plan(sort_onward)?;
let new_schema = new_child.schema();
- let new_plan = Arc::new(WindowAggExec::try_new(
- window_expr,
- new_child,
- new_schema,
- window_agg_exec.partition_keys.clone(),
- Some(physical_ordering.to_vec()),
- )?);
+
+ let uses_bounded_memory = window_expr.iter().all(|e|
e.uses_bounded_memory());
+ // If all window exprs can run with bounded memory choose bounded
window variant
+ let new_plan = if uses_bounded_memory {
+ Arc::new(BoundedWindowAggExec::try_new(
+ window_expr,
+ new_child,
+ new_schema,
+ partition_keys.to_vec(),
+ Some(physical_ordering.to_vec()),
+ )?) as _
+ } else {
+ Arc::new(WindowAggExec::try_new(
+ window_expr,
+ new_child,
+ new_schema,
+ partition_keys.to_vec(),
+ Some(physical_ordering.to_vec()),
+ )?) as _
+ };
return Ok(Some(PlanWithCorrespondingSort::new(new_plan)));
}
}
diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs
b/datafusion/core/src/physical_optimizer/pipeline_checker.rs
index c35ef29f2..96f0b0ff6 100644
--- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs
+++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs
@@ -301,7 +301,7 @@ mod sql_tests {
let case = QueryCase {
sql: "SELECT
c9,
- SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN
1 PRECEDING AND 5 FOLLOWING) as sum1
+ SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN
1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1
FROM test
LIMIT 5".to_string(),
cases: vec![Arc::new(test1), Arc::new(test2)],
@@ -325,7 +325,7 @@ mod sql_tests {
let case = QueryCase {
sql: "SELECT
c9,
- SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING
AND 5 FOLLOWING) as sum1
+ SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING
AND UNBOUNDED FOLLOWING) as sum1
FROM test".to_string(),
cases: vec![Arc::new(test1), Arc::new(test2)],
error_operator: "Window Error".to_string()
diff --git a/datafusion/core/src/physical_plan/common.rs
b/datafusion/core/src/physical_plan/common.rs
index 7ea8dfe35..f08652f7c 100644
--- a/datafusion/core/src/physical_plan/common.rs
+++ b/datafusion/core/src/physical_plan/common.rs
@@ -22,6 +22,7 @@ use crate::error::{DataFusionError, Result};
use crate::execution::context::TaskContext;
use crate::physical_plan::metrics::MemTrackingMetrics;
use crate::physical_plan::{displayable, ColumnStatistics, ExecutionPlan,
Statistics};
+use arrow::compute::concat;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::ArrowError;
use arrow::error::Result as ArrowResult;
@@ -95,6 +96,47 @@ pub async fn collect(stream: SendableRecordBatchStream) ->
Result<Vec<RecordBatc
.map_err(DataFusionError::from)
}
+/// Merge two record batch references into a single record batch.
+/// All the record batches inside the slice must have the same schema.
+pub fn merge_batches(
+ first: &RecordBatch,
+ second: &RecordBatch,
+ schema: SchemaRef,
+) -> ArrowResult<RecordBatch> {
+ let columns = (0..schema.fields.len())
+ .map(|index| {
+ let first_column = first.column(index).as_ref();
+ let second_column = second.column(index).as_ref();
+ concat(&[first_column, second_column])
+ })
+ .collect::<ArrowResult<Vec<_>>>()?;
+ RecordBatch::try_new(schema, columns)
+}
+
+/// Merge a slice of record batch references into a single record batch, or
+/// return None if the slice itself is empty. All the record batches inside the
+/// slice must have the same schema.
+pub fn merge_multiple_batches(
+ batches: &[&RecordBatch],
+ schema: SchemaRef,
+) -> ArrowResult<Option<RecordBatch>> {
+ Ok(if batches.is_empty() {
+ None
+ } else {
+ let columns = (0..schema.fields.len())
+ .map(|index| {
+ concat(
+ &batches
+ .iter()
+ .map(|batch| batch.column(index).as_ref())
+ .collect::<Vec<_>>(),
+ )
+ })
+ .collect::<ArrowResult<Vec<_>>>()?;
+ Some(RecordBatch::try_new(schema, columns)?)
+ })
+}
+
/// Recursively builds a list of files in a directory with a given extension
pub fn build_checked_file_list(dir: &str, ext: &str) -> Result<Vec<String>> {
let mut filenames: Vec<String> = Vec::new();
diff --git a/datafusion/core/src/physical_plan/planner.rs
b/datafusion/core/src/physical_plan/planner.rs
index 978ed195b..217ab2aa4 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -47,7 +47,7 @@ use crate::physical_plan::limit::{GlobalLimitExec,
LocalLimitExec};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::repartition::RepartitionExec;
use crate::physical_plan::sorts::sort::SortExec;
-use crate::physical_plan::windows::WindowAggExec;
+use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec};
use crate::physical_plan::{joins::utils as join_utils, Partitioning};
use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr,
WindowExpr};
use crate::{
@@ -614,13 +614,28 @@ impl DefaultPhysicalPlanner {
})
.collect::<Result<Vec<_>>>()?;
- Ok(Arc::new(WindowAggExec::try_new(
- window_expr,
- input_exec,
- physical_input_schema,
- physical_partition_keys,
- physical_sort_keys,
- )?))
+ let uses_bounded_memory = window_expr
+ .iter()
+ .all(|e| e.uses_bounded_memory());
+ // If all window expressions can run with bounded memory,
+ // choose the bounded window variant:
+ Ok(if uses_bounded_memory {
+ Arc::new(BoundedWindowAggExec::try_new(
+ window_expr,
+ input_exec,
+ physical_input_schema,
+ physical_partition_keys,
+ physical_sort_keys,
+ )?)
+ } else {
+ Arc::new(WindowAggExec::try_new(
+ window_expr,
+ input_exec,
+ physical_input_schema,
+ physical_partition_keys,
+ physical_sort_keys,
+ )?)
+ })
}
LogicalPlan::Aggregate(Aggregate {
input,
diff --git
a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs
b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs
new file mode 100644
index 000000000..5ed6a112c
--- /dev/null
+++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs
@@ -0,0 +1,705 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Stream and channel implementations for window function expressions.
+//! The executor given here uses bounded memory (does not maintain all
+//! the input data seen so far), which makes it appropriate when processing
+//! infinite inputs.
+
+use crate::error::Result;
+use crate::execution::context::TaskContext;
+use crate::physical_plan::expressions::PhysicalSortExpr;
+use crate::physical_plan::metrics::{
+ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
+};
+use crate::physical_plan::{
+ ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan,
Partitioning,
+ RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr,
+};
+use arrow::array::Array;
+use arrow::compute::{concat, lexicographical_partition_ranges, SortColumn};
+use arrow::{
+ array::ArrayRef,
+ datatypes::{Schema, SchemaRef},
+ error::Result as ArrowResult,
+ record_batch::RecordBatch,
+};
+use datafusion_common::{DataFusionError, ScalarValue};
+use futures::stream::Stream;
+use futures::{ready, StreamExt};
+use std::any::Any;
+use std::cmp::min;
+use std::collections::HashMap;
+use std::ops::Range;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use crate::physical_plan::common::merge_batches;
+use datafusion_physical_expr::window::{
+ PartitionBatchState, PartitionBatches, PartitionKey,
PartitionWindowAggStates,
+ WindowAggState, WindowState,
+};
+use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
+use indexmap::IndexMap;
+use log::debug;
+
+/// Window execution plan
+#[derive(Debug)]
+pub struct BoundedWindowAggExec {
+ /// Input plan
+ input: Arc<dyn ExecutionPlan>,
+ /// Window function expression
+ window_expr: Vec<Arc<dyn WindowExpr>>,
+ /// Schema after the window is run
+ schema: SchemaRef,
+ /// Schema before the window
+ input_schema: SchemaRef,
+ /// Partition Keys
+ pub partition_keys: Vec<Arc<dyn PhysicalExpr>>,
+ /// Sort Keys
+ pub sort_keys: Option<Vec<PhysicalSortExpr>>,
+ /// Execution metrics
+ metrics: ExecutionPlanMetricsSet,
+}
+
+impl BoundedWindowAggExec {
+ /// Create a new execution plan for window aggregates
+ pub fn try_new(
+ window_expr: Vec<Arc<dyn WindowExpr>>,
+ input: Arc<dyn ExecutionPlan>,
+ input_schema: SchemaRef,
+ partition_keys: Vec<Arc<dyn PhysicalExpr>>,
+ sort_keys: Option<Vec<PhysicalSortExpr>>,
+ ) -> Result<Self> {
+ let schema = create_schema(&input_schema, &window_expr)?;
+ let schema = Arc::new(schema);
+ Ok(Self {
+ input,
+ window_expr,
+ schema,
+ input_schema,
+ partition_keys,
+ sort_keys,
+ metrics: ExecutionPlanMetricsSet::new(),
+ })
+ }
+
+ /// Window expressions
+ pub fn window_expr(&self) -> &[Arc<dyn WindowExpr>] {
+ &self.window_expr
+ }
+
+ /// Input plan
+ pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
+ &self.input
+ }
+
+ /// Get the input schema before any window functions are applied
+ pub fn input_schema(&self) -> SchemaRef {
+ self.input_schema.clone()
+ }
+
+ /// Return the output sort order of partition keys: For example
+ /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a
+ // We are sure that partition by columns are always at the beginning of
sort_keys
+ // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY`
columns can be used safely
+ // to calculate partition separation points
+ pub fn partition_by_sort_keys(&self) -> Result<Vec<PhysicalSortExpr>> {
+ let mut result = vec![];
+ // All window exprs have the same partition by, so we just use the
first one:
+ let partition_by = self.window_expr()[0].partition_by();
+ let sort_keys = self.sort_keys.as_deref().unwrap_or(&[]);
+ for item in partition_by {
+ if let Some(a) = sort_keys.iter().find(|&e| e.expr.eq(item)) {
+ result.push(a.clone());
+ } else {
+ return Err(DataFusionError::Internal(
+ "Partition key not found in sort keys".to_string(),
+ ));
+ }
+ }
+ Ok(result)
+ }
+}
+
+impl ExecutionPlan for BoundedWindowAggExec {
+ /// Return a reference to Any that can be used for downcasting
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+
+ fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+ vec![self.input.clone()]
+ }
+
+ /// Get the output partitioning of this plan
+ fn output_partitioning(&self) -> Partitioning {
+ // As we can have repartitioning using the partition keys, this can
+ // be either one or more than one, depending on the presence of
+ // repartitioning.
+ self.input.output_partitioning()
+ }
+
+ fn unbounded_output(&self, children: &[bool]) -> Result<bool> {
+ Ok(children[0])
+ }
+
+ fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+ self.input().output_ordering()
+ }
+
+ fn required_input_ordering(&self) -> Vec<Option<&[PhysicalSortExpr]>> {
+ let sort_keys = self.sort_keys.as_deref();
+ vec![sort_keys]
+ }
+
+ fn required_input_distribution(&self) -> Vec<Distribution> {
+ if self.partition_keys.is_empty() {
+ debug!("No partition defined for BoundedWindowAggExec!!!");
+ vec![Distribution::SinglePartition]
+ } else {
+ //TODO support PartitionCollections if there is no common
partition columns in the window_expr
+ vec![Distribution::HashPartitioned(self.partition_keys.clone())]
+ }
+ }
+
+ fn equivalence_properties(&self) -> EquivalenceProperties {
+ self.input().equivalence_properties()
+ }
+
+ fn maintains_input_order(&self) -> bool {
+ true
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ children: Vec<Arc<dyn ExecutionPlan>>,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ Ok(Arc::new(BoundedWindowAggExec::try_new(
+ self.window_expr.clone(),
+ children[0].clone(),
+ self.input_schema.clone(),
+ self.partition_keys.clone(),
+ self.sort_keys.clone(),
+ )?))
+ }
+
+ fn execute(
+ &self,
+ partition: usize,
+ context: Arc<TaskContext>,
+ ) -> Result<SendableRecordBatchStream> {
+ let input = self.input.execute(partition, context)?;
+ let stream = Box::pin(SortedPartitionByBoundedWindowStream::new(
+ self.schema.clone(),
+ self.window_expr.clone(),
+ input,
+ BaselineMetrics::new(&self.metrics, partition),
+ self.partition_by_sort_keys()?,
+ ));
+ Ok(stream)
+ }
+
+ fn fmt_as(
+ &self,
+ t: DisplayFormatType,
+ f: &mut std::fmt::Formatter,
+ ) -> std::fmt::Result {
+ match t {
+ DisplayFormatType::Default => {
+ write!(f, "BoundedWindowAggExec: ")?;
+ let g: Vec<String> = self
+ .window_expr
+ .iter()
+ .map(|e| {
+ format!(
+ "{}: {:?}, frame: {:?}",
+ e.name().to_owned(),
+ e.field(),
+ e.get_window_frame()
+ )
+ })
+ .collect();
+ write!(f, "wdw=[{}]", g.join(", "))?;
+ }
+ }
+ Ok(())
+ }
+
+ fn metrics(&self) -> Option<MetricsSet> {
+ Some(self.metrics.clone_inner())
+ }
+
+ fn statistics(&self) -> Statistics {
+ let input_stat = self.input.statistics();
+ let win_cols = self.window_expr.len();
+ let input_cols = self.input_schema.fields().len();
+ // TODO stats: some windowing function will maintain invariants such
as min, max...
+ let mut column_statistics = Vec::with_capacity(win_cols + input_cols);
+ if let Some(input_col_stats) = input_stat.column_statistics {
+ column_statistics.extend(input_col_stats);
+ } else {
+ column_statistics.extend(vec![ColumnStatistics::default();
input_cols]);
+ }
+ column_statistics.extend(vec![ColumnStatistics::default(); win_cols]);
+ Statistics {
+ is_exact: input_stat.is_exact,
+ num_rows: input_stat.num_rows,
+ column_statistics: Some(column_statistics),
+ total_byte_size: None,
+ }
+ }
+}
+
+fn create_schema(
+ input_schema: &Schema,
+ window_expr: &[Arc<dyn WindowExpr>],
+) -> Result<Schema> {
+ let mut fields = Vec::with_capacity(input_schema.fields().len() +
window_expr.len());
+ fields.extend_from_slice(input_schema.fields());
+ // append results to the schema
+ for expr in window_expr {
+ fields.push(expr.field()?);
+ }
+ Ok(Schema::new(fields))
+}
+
+/// This trait defines the interface for updating the state and calculating
+/// results for window functions. Depending on the partitioning scheme, one
+/// may have different implementations for the functions within.
+pub trait PartitionByHandler {
+ /// Constructs output columns from window_expression results.
+ fn calculate_out_columns(&self) -> Result<Option<Vec<ArrayRef>>>;
+ /// Prunes the window state to remove any unnecessary information
+ /// given how many rows we emitted so far.
+ fn prune_state(&mut self, n_out: usize) -> Result<()>;
+ /// Updates record batches for each partition when new batches are
+ /// received.
+ fn update_partition_batch(&mut self, record_batch: RecordBatch) ->
Result<()>;
+}
+
+/// stream for window aggregation plan
+/// assuming partition by column is sorted (or without PARTITION BY expression)
+pub struct SortedPartitionByBoundedWindowStream {
+ schema: SchemaRef,
+ input: SendableRecordBatchStream,
+ /// The record batch executor receives as input (i.e. the columns needed
+ /// while calculating aggregation results).
+ input_buffer: RecordBatch,
+ /// We separate `input_buffer_record_batch` based on partitions (as
+ /// determined by PARTITION BY columns) and store them per partition
+ /// in `partition_batches`. We use this variable when calculating results
+ /// for each window expression. This enables us to use the same batch for
+ /// different window expressions without copying.
+ // Note that we could keep record batches for each window expression in
+ // `PartitionWindowAggStates`. However, this would use more memory (as
+ // many times as the number of window expressions).
+ partition_buffers: PartitionBatches,
+ /// An executor can run multiple window expressions if the PARTITION BY
+ /// and ORDER BY sections are same. We keep state of the each window
+ /// expression inside `window_agg_states`.
+ window_agg_states: Vec<PartitionWindowAggStates>,
+ finished: bool,
+ window_expr: Vec<Arc<dyn WindowExpr>>,
+ partition_by_sort_keys: Vec<PhysicalSortExpr>,
+ baseline_metrics: BaselineMetrics,
+}
+
+impl PartitionByHandler for SortedPartitionByBoundedWindowStream {
+ /// This method constructs output columns using the result of each window
expression
+ fn calculate_out_columns(&self) -> Result<Option<Vec<ArrayRef>>> {
+ let n_out = self.calculate_n_out_row();
+ if n_out == 0 {
+ Ok(None)
+ } else {
+ self.input_buffer
+ .columns()
+ .iter()
+ .map(|elem| Ok(elem.slice(0, n_out)))
+ .chain(
+ self.window_agg_states
+ .iter()
+ .map(|elem| get_aggregate_result_out_column(elem,
n_out)),
+ )
+ .collect::<Result<Vec<_>>>()
+ .map(Some)
+ }
+ }
+
+ /// Prunes sections of the state that are no longer needed when calculating
+ /// results (as determined by window frame boundaries and number of
results generated).
+ // For instance, if first `n` (not necessarily same with `n_out`) elements
are no longer needed to
+ // calculate window expression result (outside the window frame boundary)
we retract first `n` elements
+ // from `self.partition_batches` in corresponding partition.
+ // For instance, if `n_out` number of rows are calculated, we can remove
+ // first `n_out` rows from `self.input_buffer_record_batch`.
+ fn prune_state(&mut self, n_out: usize) -> Result<()> {
+ // Prune `self.partition_batches`:
+ self.prune_partition_batches()?;
+ // Prune `self.input_buffer_record_batch`:
+ self.prune_input_batch(n_out)?;
+ // Prune `self.window_agg_states`:
+ self.prune_out_columns(n_out)?;
+ Ok(())
+ }
+
    /// Folds an incoming `record_batch` into the stream's buffered state.
    ///
    /// Rows are grouped by their PARTITION BY values: each group is appended
    /// to its entry in `self.partition_buffers` (a new entry is created for a
    /// previously unseen partition). Every buffered partition except the last
    /// one is then flagged as ended, and the whole batch is appended to
    /// `self.input_buffer`.
    fn update_partition_batch(&mut self, record_batch: RecordBatch) -> Result<()> {
        let partition_columns = self.partition_columns(&record_batch)?;
        let num_rows = record_batch.num_rows();
        if num_rows > 0 {
            let partition_points =
                self.evaluate_partition_points(num_rows, &partition_columns)?;
            for partition_range in partition_points {
                // Build the key for this partition from the PARTITION BY
                // column values at the first row of the range.
                let partition_row = partition_columns
                    .iter()
                    .map(|arr| {
                        ScalarValue::try_from_array(&arr.values, partition_range.start)
                    })
                    .collect::<Result<PartitionKey>>()?;
                let partition_batch = record_batch.slice(
                    partition_range.start,
                    partition_range.end - partition_range.start,
                );
                if let Some(partition_batch_state) =
                    self.partition_buffers.get_mut(&partition_row)
                {
                    // Known partition: append the new rows to its buffer.
                    partition_batch_state.record_batch = merge_batches(
                        &partition_batch_state.record_batch,
                        &partition_batch,
                        self.input.schema(),
                    )?;
                } else {
                    // New partition: start buffering its rows.
                    let partition_batch_state = PartitionBatchState {
                        record_batch: partition_batch,
                        is_end: false,
                    };
                    self.partition_buffers
                        .insert(partition_row, partition_batch_state);
                };
            }
        }
        // Every partition except the most recently seen one is marked as
        // complete (assumes input arrives sorted by the partition keys, as
        // the stream's name implies -- NOTE(review): confirm against callers).
        let n_partitions = self.partition_buffers.len();
        for (idx, (_, partition_batch_state)) in
            self.partition_buffers.iter_mut().enumerate()
        {
            partition_batch_state.is_end |= idx < n_partitions - 1;
        }
        self.input_buffer = if self.input_buffer.num_rows() == 0 {
            record_batch
        } else {
            merge_batches(&self.input_buffer, &record_batch, self.input.schema())?
        };

        Ok(())
    }
+}
+
impl Stream for SortedPartitionByBoundedWindowStream {
    type Item = ArrowResult<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Delegate to the inner implementation and record the poll outcome
        // in the baseline metrics.
        let poll = self.poll_next_inner(cx);
        self.baseline_metrics.record_poll(poll)
    }
}
+
+impl SortedPartitionByBoundedWindowStream {
+ /// Create a new BoundedWindowAggStream
+ pub fn new(
+ schema: SchemaRef,
+ window_expr: Vec<Arc<dyn WindowExpr>>,
+ input: SendableRecordBatchStream,
+ baseline_metrics: BaselineMetrics,
+ partition_by_sort_keys: Vec<PhysicalSortExpr>,
+ ) -> Self {
+ let state = window_expr.iter().map(|_| IndexMap::new()).collect();
+ let empty_batch = RecordBatch::new_empty(schema.clone());
+ Self {
+ schema,
+ input,
+ input_buffer: empty_batch,
+ partition_buffers: IndexMap::new(),
+ window_agg_states: state,
+ finished: false,
+ window_expr,
+ baseline_metrics,
+ partition_by_sort_keys,
+ }
+ }
+
+ fn compute_aggregates(&mut self) -> ArrowResult<RecordBatch> {
+ // calculate window cols
+ for (cur_window_expr, state) in
+ self.window_expr.iter().zip(&mut self.window_agg_states)
+ {
+ cur_window_expr.evaluate_stateful(&self.partition_buffers, state)?;
+ }
+
+ let schema = self.schema.clone();
+ let columns_to_show = self.calculate_out_columns()?;
+ if let Some(columns_to_show) = columns_to_show {
+ let n_generated = columns_to_show[0].len();
+ self.prune_state(n_generated)?;
+ RecordBatch::try_new(schema, columns_to_show)
+ } else {
+ Ok(RecordBatch::new_empty(schema))
+ }
+ }
+
    /// Inner poll logic: pulls one batch from the input (when ready), folds
    /// it into the buffered state, and computes any window results that can
    /// be emitted so far.
    #[inline]
    fn poll_next_inner(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<ArrowResult<RecordBatch>>> {
        if self.finished {
            return Poll::Ready(None);
        }

        let result = match ready!(self.input.poll_next_unpin(cx)) {
            Some(Ok(batch)) => {
                self.update_partition_batch(batch)?;
                self.compute_aggregates()
            }
            Some(Err(e)) => Err(e),
            None => {
                // Input is exhausted: mark every partition as complete so all
                // remaining results are generated, then flush them.
                self.finished = true;
                for (_, partition_batch_state) in self.partition_buffers.iter_mut() {
                    partition_batch_state.is_end = true;
                }
                self.compute_aggregates()
            }
        };
        Poll::Ready(Some(result))
    }
+
+ /// Calculates how many rows [SortedPartitionByBoundedWindowStream]
+ /// can produce as output.
+ fn calculate_n_out_row(&self) -> usize {
+ // Different window aggregators may produce results with different
rates.
+ // We produce the overall batch result with the same speed as slowest
one.
+ self.window_agg_states
+ .iter()
+ .map(|window_agg_state| {
+ // Store how many elements are generated for the current
+ // window expression:
+ let mut cur_window_expr_out_result_len = 0;
+ // We iterate over `window_agg_state`, which is an IndexMap.
+ // Iterations follow the insertion order, hence we preserve
+ // sorting when partition columns are sorted.
+ for (_, WindowState { state, .. }) in window_agg_state.iter() {
+ cur_window_expr_out_result_len += state.out_col.len();
+ // If we do not generate all results for the current
+ // partition, we do not generate results for next
+ // partition -- otherwise we will lose input ordering.
+ if state.n_row_result_missing > 0 {
+ break;
+ }
+ }
+ cur_window_expr_out_result_len
+ })
+ .into_iter()
+ .min()
+ .unwrap_or(0)
+ }
+
+ /// Prunes the sections of the record batch (for each partition)
+ /// that we no longer need to calculate the window function result.
+ fn prune_partition_batches(&mut self) -> Result<()> {
+ // Remove partitions which we know already ended (is_end flag is true).
+ // Since the retain method preserves insertion order, we still have
+ // ordering in between partitions after removal.
+ self.partition_buffers
+ .retain(|_, partition_batch_state| !partition_batch_state.is_end);
+
+ // The data in `self.partition_batches` is used by all window
expressions.
+ // Therefore, when removing from `self.partition_batches`, we need to
remove
+ // from the earliest range boundary among all window expressions.
Variable
+ // `n_prune_each_partition` fill the earliest range boundary
information for
+ // each partition. This way, we can delete the no-longer-needed
sections from
+ // `self.partition_batches`.
+ // For instance, if window frame one uses [10, 20] and window frame
two uses
+ // [5, 15]; we only prune the first 5 elements from the corresponding
record
+ // batch in `self.partition_batches`.
+
+ // Calculate how many elements to prune for each partition batch
+ let mut n_prune_each_partition: HashMap<PartitionKey, usize> =
HashMap::new();
+ for window_agg_state in self.window_agg_states.iter_mut() {
+ window_agg_state.retain(|_, WindowState { state, .. }|
!state.is_end);
+ for (partition_row, WindowState { state: value, .. }) in
window_agg_state {
+ if let Some(state) =
n_prune_each_partition.get_mut(partition_row) {
+ if value.window_frame_range.start < *state {
+ *state = value.window_frame_range.start;
+ }
+ } else {
+ n_prune_each_partition
+ .insert(partition_row.clone(),
value.window_frame_range.start);
+ }
+ }
+ }
+
+ let err = || DataFusionError::Execution("Expects to have
partition".to_string());
+ // Retract no longer needed parts during window calculations from
partition batch:
+ for (partition_row, n_prune) in n_prune_each_partition.iter() {
+ let partition_batch_state = self
+ .partition_buffers
+ .get_mut(partition_row)
+ .ok_or_else(err)?;
+ let batch = &partition_batch_state.record_batch;
+ partition_batch_state.record_batch =
+ batch.slice(*n_prune, batch.num_rows() - n_prune);
+
+ // Update state indices since we have pruned some rows from the
beginning:
+ for window_agg_state in self.window_agg_states.iter_mut() {
+ let window_state =
+ window_agg_state.get_mut(partition_row).ok_or_else(err)?;
+ let mut state = &mut window_state.state;
+ state.window_frame_range = Range {
+ start: state.window_frame_range.start - n_prune,
+ end: state.window_frame_range.end - n_prune,
+ };
+ state.last_calculated_index -= n_prune;
+ state.offset_pruned_rows += n_prune;
+ }
+ }
+ Ok(())
+ }
+
+ /// Prunes the section of the input batch whose aggregate results
+ /// are calculated and emitted.
+ fn prune_input_batch(&mut self, n_out: usize) -> Result<()> {
+ let n_to_keep = self.input_buffer.num_rows() - n_out;
+ let batch_to_keep = self
+ .input_buffer
+ .columns()
+ .iter()
+ .map(|elem| elem.slice(n_out, n_to_keep))
+ .collect::<Vec<_>>();
+ self.input_buffer =
+ RecordBatch::try_new(self.input_buffer.schema(), batch_to_keep)?;
+ Ok(())
+ }
+
+ /// Prunes emitted parts from WindowAggState `out_col` field.
+ fn prune_out_columns(&mut self, n_out: usize) -> Result<()> {
+ // We store generated columns for each window expression in the
`out_col`
+ // field of `WindowAggState`. Given how many rows are emitted, we
remove
+ // these sections from state.
+ for partition_window_agg_states in self.window_agg_states.iter_mut() {
+ let mut running_length = 0;
+ // Remove `n_out` entries from the `out_col` field of
`WindowAggState`.
+ // Preserve per partition ordering by iterating in the order of
insertion.
+ // Do not generate a result for a new partition without emitting
all results
+ // for the current partition.
+ for (
+ _,
+ WindowState {
+ state: WindowAggState { out_col, .. },
+ ..
+ },
+ ) in partition_window_agg_states
+ {
+ if running_length < n_out {
+ let n_to_del = min(out_col.len(), n_out - running_length);
+ let n_to_keep = out_col.len() - n_to_del;
+ *out_col = out_col.slice(n_to_del, n_to_keep);
+ running_length += n_to_del;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ /// Get Partition Columns
+ pub fn partition_columns(&self, batch: &RecordBatch) ->
Result<Vec<SortColumn>> {
+ self.partition_by_sort_keys
+ .iter()
+ .map(|e| e.evaluate_to_sort_column(batch))
+ .collect::<Result<Vec<_>>>()
+ }
+
+ /// evaluate the partition points given the sort columns; if the sort
columns are
+ /// empty then the result will be a single element vec of the whole column
rows.
+ fn evaluate_partition_points(
+ &self,
+ num_rows: usize,
+ partition_columns: &[SortColumn],
+ ) -> Result<Vec<Range<usize>>> {
+ Ok(if partition_columns.is_empty() {
+ vec![Range {
+ start: 0,
+ end: num_rows,
+ }]
+ } else {
+ lexicographical_partition_ranges(partition_columns)
+ .map_err(DataFusionError::ArrowError)?
+ .collect::<Vec<_>>()
+ })
+ }
+}
+
+impl RecordBatchStream for SortedPartitionByBoundedWindowStream {
+ /// Get the schema
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+}
+
+/// Calculates the section we can show results for expression
+fn get_aggregate_result_out_column(
+ partition_window_agg_states: &PartitionWindowAggStates,
+ len_to_show: usize,
+) -> Result<ArrayRef> {
+ let mut result = None;
+ let mut running_length = 0;
+ // We assume that iteration order is according to insertion order
+ for (
+ _,
+ WindowState {
+ state: WindowAggState { out_col, .. },
+ ..
+ },
+ ) in partition_window_agg_states
+ {
+ if running_length < len_to_show {
+ let n_to_use = min(len_to_show - running_length, out_col.len());
+ let slice_to_use = out_col.slice(0, n_to_use);
+ result = Some(match result {
+ Some(arr) => concat(&[&arr, &slice_to_use])?,
+ None => slice_to_use,
+ });
+ running_length += n_to_use;
+ } else {
+ break;
+ }
+ }
+ if running_length != len_to_show {
+ return Err(DataFusionError::Execution(format!(
+ "Generated row number should be {}, it is {}",
+ len_to_show, running_length
+ )));
+ }
+ result
+ .ok_or_else(|| DataFusionError::Execution("Should contain
something".to_string()))
+}
diff --git a/datafusion/core/src/physical_plan/windows/mod.rs
b/datafusion/core/src/physical_plan/windows/mod.rs
index bbf6c9182..2d7aa0494 100644
--- a/datafusion/core/src/physical_plan/windows/mod.rs
+++ b/datafusion/core/src/physical_plan/windows/mod.rs
@@ -39,8 +39,10 @@ use datafusion_physical_expr::window::{
use std::convert::TryInto;
use std::sync::Arc;
+mod bounded_window_agg_exec;
mod window_agg_exec;
+pub use bounded_window_agg_exec::BoundedWindowAggExec;
pub use datafusion_physical_expr::window::{
AggregateWindowExpr, BuiltInWindowExpr, WindowExpr,
};
diff --git a/datafusion/core/tests/sql/window.rs
b/datafusion/core/tests/sql/window.rs
index d3a2043f1..5ca49cff2 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -16,6 +16,8 @@
// under the License.
use super::*;
+use ::parquet::arrow::arrow_writer::ArrowWriter;
+use ::parquet::file::properties::WriterProperties;
/// for window functions without order by the first, last, and nth function
call does not make sense
#[tokio::test]
@@ -1757,11 +1759,11 @@ async fn test_window_partition_by_order_by() ->
Result<()> {
let expected = {
vec![
"ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY
[aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2
ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as
SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY
[aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS
BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(UInt8(1))]",
- " WindowAggExec: wdw=[COUNT(UInt8(1)): Ok(Field { name:
\"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]",
+ " BoundedWindowAggExec: wdw=[COUNT(UInt8(1)): Ok(Field { name:
\"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]",
" SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name:
\"c1\", index: 0 }], 2)",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c4):
Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable:
true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1))
}]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4):
Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable:
true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1))
}]",
" SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column {
name: \"c1\", index: 0 }, Column { name: \"c2\", index: 1 }], 2)",
@@ -1800,8 +1802,8 @@ async fn test_window_agg_sort_reversed_plan() ->
Result<()> {
"ProjectionExec: expr=[c9@0 as c9, SUM(aggregate_test_100.c9)
ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9
DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" GlobalLimitExec: skip=0, fetch=5",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field {
name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1))
}]",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field
{ name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9):
Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable:
true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1))
}]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9):
Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable:
true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}]",
" SortExec: [c9@0 DESC]",
]
};
@@ -1856,8 +1858,8 @@ async fn test_window_agg_sort_reversed_plan_builtin() ->
Result<()> {
"ProjectionExec: expr=[c9@0 as c9,
FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS
LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1,
FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS
FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2,
LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND
CURRENT ROW@5 as lag1, LAG(aggregate_tes [...]
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" GlobalLimitExec: skip=0, fetch=5",
- " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9):
Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32,
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame:
WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound:
Following(UInt64(1)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)):
Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\",
data_type: UInt32, nullable: true, dict_id: 0, [...]
- " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9):
Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32,
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame:
WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound:
Following(UInt64(5)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)):
Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\",
data_type: UInt32, nullable: true, dict_id: 0 [...]
+ " BoundedWindowAggExec:
wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name:
\"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1))
}, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name:
\"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32,
nullable: true, dict_ [...]
+ " BoundedWindowAggExec:
wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name:
\"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name:
\"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32,
nullable: true, dic [...]
" SortExec: [c9@0 DESC]",
]
};
@@ -1908,9 +1910,9 @@ async fn test_window_agg_sort_non_reversed_plan() ->
Result<()> {
"ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING@2 as rn1, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS
FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2]",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" GlobalLimitExec: skip=0, fetch=5",
- " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name:
\"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name:
\"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
" SortExec: [c9@0 ASC NULLS LAST]",
- " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name:
\"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field {
name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
" SortExec: [c9@0 DESC]",
]
};
@@ -1962,10 +1964,10 @@ async fn
test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> {
"ProjectionExec: expr=[c9@2 as c9, SUM(aggregate_test_100.c9)
ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS
LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING@5 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9
DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1
PRECEDING AND 5 FOLLOWING@3 as sum2, ROW_NUMBER() ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWE [...]
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" GlobalLimitExec: skip=0, fetch=5",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field {
name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9):
Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable:
true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}]",
" SortExec: [c9@2 ASC NULLS LAST,c1@0 ASC NULLS LAST,c2@1
ASC NULLS LAST]",
- " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name:
\"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c9):
Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable:
true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}]",
+ " BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field {
name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " BoundedWindowAggExec:
wdw=[SUM(aggregate_test_100.c9): Ok(Field { name:
\"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
" SortExec: [c9@2 DESC,c1@0 DESC]",
]
};
@@ -2099,8 +2101,8 @@ async fn
test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()>
"ProjectionExec: expr=[c9@1 as c9, SUM(aggregate_test_100.c9)
ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC
NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1,
SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING@3 as sum2]",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" GlobalLimitExec: skip=0, fetch=5",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field {
name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}]",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field
{ name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9):
Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable:
true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9):
Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable:
true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}]",
" SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]",
]
};
@@ -2154,8 +2156,8 @@ async fn test_window_agg_sort_partitionby_reversed_plan()
-> Result<()> {
"ProjectionExec: expr=[c9@1 as c9, SUM(aggregate_test_100.c9)
PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS
LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum1,
SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING@2 as sum2]",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" GlobalLimitExec: skip=0, fetch=5",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field {
name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1))
}]",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field
{ name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9):
Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable:
true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1))
}]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9):
Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable:
true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}]",
" SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]",
]
};
@@ -2351,3 +2353,251 @@ async fn
test_window_agg_sort_orderby_reversed_partitionby_reversed_plan() -> Re
Ok(())
}
+
+fn write_test_data_to_parquet(tmpdir: &TempDir, n_file: usize) -> Result<()> {
+ let ts_field = Field::new("ts", DataType::Int32, false);
+ let inc_field = Field::new("inc_col", DataType::Int32, false);
+ let desc_field = Field::new("desc_col", DataType::Int32, false);
+
+ let schema = Arc::new(Schema::new(vec![ts_field, inc_field, desc_field]));
+
+ let batch = RecordBatch::try_new(
+ schema,
+ vec![
+ Arc::new(Int32Array::from_slice([
+ 1, 1, 5, 9, 10, 11, 16, 21, 22, 26, 26, 28, 31, 33, 38, 42,
47, 51, 53,
+ 53, 58, 63, 67, 68, 70, 72, 72, 76, 81, 85, 86, 88, 91, 96,
97, 98, 100,
+ 101, 102, 104, 104, 108, 112, 113, 113, 114, 114, 117, 122,
126, 131,
+ 131, 136, 136, 136, 139, 141, 146, 147, 147, 152, 154, 159,
161, 163,
+ 164, 167, 172, 173, 177, 180, 185, 186, 191, 195, 195, 199,
203, 207,
+ 210, 213, 218, 221, 224, 226, 230, 232, 235, 238, 238, 239,
244, 245,
+ 247, 250, 254, 258, 262, 264, 264,
+ ])),
+ Arc::new(Int32Array::from_slice([
+ 1, 5, 10, 15, 20, 21, 26, 29, 30, 33, 37, 40, 43, 44, 45, 49,
51, 53, 58,
+ 61, 65, 70, 75, 78, 83, 88, 90, 91, 95, 97, 100, 105, 109,
111, 115, 119,
+ 120, 124, 126, 129, 131, 135, 140, 143, 144, 147, 148, 149,
151, 155,
+ 156, 159, 160, 163, 165, 170, 172, 177, 181, 182, 186, 187,
192, 196,
+ 197, 199, 203, 207, 209, 213, 214, 216, 219, 221, 222, 225,
226, 231,
+ 236, 237, 242, 245, 247, 248, 253, 254, 259, 261, 266, 269,
272, 275,
+ 278, 283, 286, 289, 291, 296, 301, 305,
+ ])),
+ Arc::new(Int32Array::from_slice([
+ 100, 98, 93, 91, 86, 84, 81, 77, 75, 71, 70, 69, 64, 62, 59,
55, 50, 45,
+ 41, 40, 39, 36, 31, 28, 23, 22, 17, 13, 10, 6, 5, 2, 1, -1,
-4, -5, -6,
+ -8, -12, -16, -17, -19, -24, -25, -29, -34, -37, -42, -47,
-48, -49, -53,
+ -57, -58, -61, -65, -67, -68, -71, -73, -75, -76, -78, -83,
-87, -91,
+ -95, -98, -101, -105, -106, -111, -114, -116, -120, -125,
-128, -129,
+ -134, -139, -142, -143, -146, -150, -154, -158, -163, -168,
-172, -176,
+ -181, -184, -189, -193, -196, -201, -203, -208, -210, -213,
+ ])),
+ ],
+ )?;
+ let n_chunk = batch.num_rows() / n_file;
+ for i in 0..n_file {
+ let target_file = tmpdir.path().join(format!("{}.parquet", i));
+ let file = File::create(target_file).unwrap();
+ // Default writer properties
+ let props = WriterProperties::builder().build();
+ let chunks_start = i * n_chunk;
+ let cur_batch = batch.slice(chunks_start, n_chunk);
+ // let chunks_end = chunks_start + n_chunk;
+ let mut writer =
+ ArrowWriter::try_new(file, cur_batch.schema(),
Some(props)).unwrap();
+
+ writer.write(&cur_batch).expect("Writing batch");
+
+ // writer must be closed to write footer
+ writer.close().unwrap();
+ }
+ Ok(())
+}
+
+async fn get_test_context(tmpdir: &TempDir) -> Result<SessionContext> {
+ let session_config = SessionConfig::new().with_target_partitions(1);
+ let ctx = SessionContext::with_config(session_config);
+
+ let parquet_read_options = ParquetReadOptions::default();
+ // The sort order is specified (not actually correct in this case)
+ let file_sort_order = [col("ts")]
+ .into_iter()
+ .map(|e| {
+ let ascending = true;
+ let nulls_first = false;
+ e.sort(ascending, nulls_first)
+ })
+ .collect::<Vec<_>>();
+
+ let options_sort = parquet_read_options
+ .to_listing_options(&ctx.copied_config())
+ .with_file_sort_order(Some(file_sort_order));
+
+ write_test_data_to_parquet(tmpdir, 1)?;
+ let provided_schema = None;
+ let sql_definition = None;
+ ctx.register_listing_table(
+ "annotated_data",
+ tmpdir.path().to_string_lossy(),
+ options_sort.clone(),
+ provided_schema,
+ sql_definition,
+ )
+ .await
+ .unwrap();
+ Ok(ctx)
+}
+
+mod tests {
+ use super::*;
+
+ #[tokio::test]
+ async fn test_source_sorted_aggregate() -> Result<()> {
+ let tmpdir = TempDir::new().unwrap();
+ let ctx = get_test_context(&tmpdir).await?;
+
+ let sql = "SELECT
+ SUM(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING AND 1
FOLLOWING) as sum1,
+ SUM(desc_col) OVER(ORDER BY ts RANGE BETWEEN 5 PRECEDING AND 1
FOLLOWING) as sum2,
+ SUM(inc_col) OVER(ORDER BY ts ROWS BETWEEN 1 PRECEDING AND 10
FOLLOWING) as sum3,
+ MIN(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING AND 1
FOLLOWING) as min1,
+ MIN(desc_col) OVER(ORDER BY ts RANGE BETWEEN 5 PRECEDING AND 1
FOLLOWING) as min2,
+ MIN(inc_col) OVER(ORDER BY ts ROWS BETWEEN 1 PRECEDING AND 10
FOLLOWING) as min3,
+ MAX(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING AND 1
FOLLOWING) as max1,
+ MAX(desc_col) OVER(ORDER BY ts RANGE BETWEEN 5 PRECEDING AND 1
FOLLOWING) as max2,
+ MAX(inc_col) OVER(ORDER BY ts ROWS BETWEEN 1 PRECEDING AND 10
FOLLOWING) as max3,
+ COUNT(*) OVER(ORDER BY ts RANGE BETWEEN 4 PRECEDING AND 8
FOLLOWING) as cnt1,
+ COUNT(*) OVER(ORDER BY ts ROWS BETWEEN 8 PRECEDING AND 1
FOLLOWING) as cnt2,
+ SUM(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING AND 4
FOLLOWING) as sumr1,
+ SUM(desc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING AND
8 FOLLOWING) as sumr2,
+ SUM(desc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING) as sumr3,
+ MIN(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 10 PRECEDING AND
1 FOLLOWING) as minr1,
+ MIN(desc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 5 PRECEDING AND
1 FOLLOWING) as minr2,
+ MIN(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 1 PRECEDING AND 10
FOLLOWING) as minr3,
+ MAX(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 10 PRECEDING AND
1 FOLLOWING) as maxr1,
+ MAX(desc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 5 PRECEDING AND
1 FOLLOWING) as maxr2,
+ MAX(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 1 PRECEDING AND 10
FOLLOWING) as maxr3,
+ COUNT(*) OVER(ORDER BY ts DESC RANGE BETWEEN 6 PRECEDING AND 2
FOLLOWING) as cntr1,
+ COUNT(*) OVER(ORDER BY ts DESC ROWS BETWEEN 8 PRECEDING AND 1
FOLLOWING) as cntr2,
+ SUM(desc_col) OVER(ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as
sum4,
+ COUNT(*) OVER(ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as cnt3
+ FROM annotated_data
+ ORDER BY inc_col DESC
+ LIMIT 5
+ ";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted =
displayable(physical_plan.as_ref()).indent().to_string();
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, sum3@2
as sum3, min1@3 as min1, min2@4 as min2, min3@5 as min3, max1@6 as max1, max2@7
as max2, max3@8 as max3, cnt1@9 as cnt1, cnt2@10 as cnt2, sumr1@11 as sumr1,
sumr2@12 as sumr2, sumr3@13 as sumr3, minr1@14 as minr1, minr2@15 as minr2,
minr3@16 as minr3, maxr1@17 as maxr1, maxr2@18 as maxr2, maxr3@19 as maxr3,
cntr1@20 as cntr1, cntr2@21 as cntr2, sum4@22 as sum4, cnt3@23 as cnt3]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " SortExec: [inc_col@24 DESC]",
+ " ProjectionExec: expr=[SUM(annotated_data.inc_col) ORDER
BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1
FOLLOWING@14 as sum1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts
ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@15 as sum2,
SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS
BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as sum3, MIN(annotated_data.inc_col)
ORDER BY [annotated_data.ts ASC NULLS L [...]
+ " BoundedWindowAggExec:
wdw=[SUM(annotated_data.desc_col): Ok(Field { name:
\"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) },
COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64,
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame:
WindowFrame { units [...]
+ " BoundedWindowAggExec:
wdw=[SUM(annotated_data.inc_col): Ok(Field { name:
\"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range,
start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) },
SUM(annotated_data.desc_col): Ok(Field { name:
\"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), [...]
+ " BoundedWindowAggExec:
wdw=[SUM(annotated_data.inc_col): Ok(Field { name:
\"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range,
start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)) },
SUM(annotated_data.desc_col): Ok(Field { name:
\"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), [...]
+ ]
+ };
+
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+
"+------+------+------+------+------+------+------+------+------+------+------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+------+",
+ "| sum1 | sum2 | sum3 | min1 | min2 | min3 | max1 | max2 | max3 |
cnt1 | cnt2 | sumr1 | sumr2 | sumr3 | minr1 | minr2 | minr3 | maxr1 | maxr2 |
maxr3 | cntr1 | cntr2 | sum4 | cnt3 |",
+
"+------+------+------+------+------+------+------+------+------+------+------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+------+",
+ "| 1482 | -631 | 606 | 289 | -213 | 301 | 305 | -208 | 305 |
3 | 9 | 902 | -834 | -1231 | 301 | -213 | 269 | 305 | -210 |
305 | 3 | 2 | -1797 | 9 |",
+ "| 1482 | -631 | 902 | 289 | -213 | 296 | 305 | -208 | 305 |
3 | 10 | 902 | -834 | -1424 | 301 | -213 | 266 | 305 | -210 |
305 | 3 | 3 | -1978 | 10 |",
+ "| 876 | -411 | 1193 | 289 | -208 | 291 | 296 | -203 | 305 |
4 | 10 | 587 | -612 | -1400 | 296 | -213 | 261 | 305 | -208 |
301 | 3 | 4 | -1941 | 10 |",
+ "| 866 | -404 | 1482 | 286 | -203 | 289 | 291 | -201 | 305 |
5 | 10 | 580 | -600 | -1374 | 291 | -208 | 259 | 305 | -203 |
296 | 4 | 5 | -1903 | 10 |",
+ "| 1411 | -397 | 1768 | 275 | -201 | 286 | 289 | -196 | 305 |
4 | 10 | 575 | -590 | -1347 | 289 | -203 | 254 | 305 | -201 |
291 | 2 | 6 | -1863 | 10 |",
+
"+------+------+------+------+------+------+------+------+------+------+------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_source_sorted_builtin() -> Result<()> {
+ let tmpdir = TempDir::new().unwrap();
+ let ctx = get_test_context(&tmpdir).await?;
+
+ let sql = "SELECT
+ FIRST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING
and 1 FOLLOWING) as fv1,
+ FIRST_VALUE(inc_col) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING
and 1 FOLLOWING) as fv2,
+ LAST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING
and 1 FOLLOWING) as lv1,
+ LAST_VALUE(inc_col) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and
1 FOLLOWING) as lv2,
+ NTH_VALUE(inc_col, 5) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING
and 1 FOLLOWING) as nv1,
+ NTH_VALUE(inc_col, 5) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING
and 1 FOLLOWING) as nv2,
+ ROW_NUMBER() OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING and 10
FOLLOWING) AS rn1,
+ ROW_NUMBER() OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1
FOLLOWING) as rn2,
+ RANK() OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING and 10
FOLLOWING) AS rank1,
+ RANK() OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING)
as rank2,
+ DENSE_RANK() OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING and 10
FOLLOWING) AS dense_rank1,
+ DENSE_RANK() OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1
FOLLOWING) as dense_rank2,
+ LAG(inc_col, 1, 1001) OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING
and 10 FOLLOWING) AS lag1,
+ LAG(inc_col, 2, 1002) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING
and 1 FOLLOWING) as lag2,
+ LEAD(inc_col, -1, 1001) OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING
and 10 FOLLOWING) AS lead1,
+ LEAD(inc_col, 4, 1004) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING
and 1 FOLLOWING) as lead2,
+ FIRST_VALUE(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 10
PRECEDING and 1 FOLLOWING) as fvr1,
+ FIRST_VALUE(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 10
PRECEDING and 1 FOLLOWING) as fvr2,
+ LAST_VALUE(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 10
PRECEDING and 1 FOLLOWING) as lvr1,
+ LAST_VALUE(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 10
PRECEDING and 1 FOLLOWING) as lvr2,
+ LAG(inc_col, 1, 1001) OVER(ORDER BY ts DESC RANGE BETWEEN 1
PRECEDING and 10 FOLLOWING) AS lagr1,
+ LAG(inc_col, 2, 1002) OVER(ORDER BY ts DESC ROWS BETWEEN 10
PRECEDING and 1 FOLLOWING) as lagr2,
+ LEAD(inc_col, -1, 1001) OVER(ORDER BY ts DESC RANGE BETWEEN 1
PRECEDING and 10 FOLLOWING) AS leadr1,
+ LEAD(inc_col, 4, 1004) OVER(ORDER BY ts DESC ROWS BETWEEN 10
PRECEDING and 1 FOLLOWING) as leadr2
+ FROM annotated_data
+ ORDER BY ts DESC
+ LIMIT 5
+ ";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted =
displayable(physical_plan.as_ref()).indent().to_string();
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[fv1@0 as fv1, fv2@1 as fv2, lv1@2 as
lv1, lv2@3 as lv2, nv1@4 as nv1, nv2@5 as nv2, rn1@6 as rn1, rn2@7 as rn2,
rank1@8 as rank1, rank2@9 as rank2, dense_rank1@10 as dense_rank1,
dense_rank2@11 as dense_rank2, lag1@12 as lag1, lag2@13 as lag2, lead1@14 as
lead1, lead2@15 as lead2, fvr1@16 as fvr1, fvr2@17 as fvr2, lvr1@18 as lvr1,
lvr2@19 as lvr2, lagr1@20 as lagr1, lagr2@21 as lagr2, leadr1@22 as leadr1,
leadr2@23 as leadr2]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " SortExec: [ts@24 DESC]",
+ " ProjectionExec:
expr=[FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS
LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1,
FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST]
ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2,
LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST]
RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1,
LAST_VALUE(annotated_data.inc_col) ORDER BY [an [...]
+ " BoundedWindowAggExec:
wdw=[FIRST_VALUE(annotated_data.inc_col): Ok(Field { name:
\"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1))
}, FIRST_VALUE(annotated_data.inc_col): Ok(Field { name:
\"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true,
dict_id: 0, dict_is_order [...]
+ " BoundedWindowAggExec:
wdw=[FIRST_VALUE(annotated_data.inc_col): Ok(Field { name:
\"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10))
}, FIRST_VALUE(annotated_data.inc_col): Ok(Field { name:
\"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true,
dict_id: 0, dict_is_ord [...]
+ ]
+ };
+
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+
"+-----+-----+-----+-----+-----+-----+-----+-----+-------+-------+-------------+-------------+------+------+-------+-------+------+------+------+------+-------+-------+--------+--------+",
+ "| fv1 | fv2 | lv1 | lv2 | nv1 | nv2 | rn1 | rn2 | rank1 | rank2 |
dense_rank1 | dense_rank2 | lag1 | lag2 | lead1 | lead2 | fvr1 | fvr2 | lvr1 |
lvr2 | lagr1 | lagr2 | leadr1 | leadr2 |",
+
"+-----+-----+-----+-----+-----+-----+-----+-----+-------+-------+-------------+-------------+------+------+-------+-------+------+------+------+------+-------+-------+--------+--------+",
+ "| 289 | 266 | 305 | 305 | 305 | 278 | 99 | 99 | 99 | 99 |
86 | 86 | 296 | 291 | 296 | 1004 | 305 | 305 | 301 |
296 | 305 | 1002 | 305 | 286 |",
+ "| 289 | 269 | 305 | 305 | 305 | 283 | 100 | 100 | 99 | 99 |
86 | 86 | 301 | 296 | 301 | 1004 | 305 | 305 | 301 |
301 | 1001 | 1002 | 1001 | 289 |",
+ "| 289 | 261 | 296 | 301 | | 275 | 98 | 98 | 98 | 98 |
85 | 85 | 291 | 289 | 291 | 1004 | 305 | 305 | 296 |
291 | 301 | 305 | 301 | 283 |",
+ "| 286 | 259 | 291 | 296 | | 272 | 97 | 97 | 97 | 97 |
84 | 84 | 289 | 286 | 289 | 1004 | 305 | 305 | 291 |
289 | 296 | 301 | 296 | 278 |",
+ "| 275 | 254 | 289 | 291 | 289 | 269 | 96 | 96 | 96 | 96 |
83 | 83 | 286 | 283 | 286 | 305 | 305 | 305 | 289 |
286 | 291 | 296 | 291 | 275 |",
+
"+-----+-----+-----+-----+-----+-----+-----+-----+-------+-------+-------------+-------------+------+------+-------+-------+------+------+------+------+-------+-------+--------+--------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+ }
+}
diff --git a/datafusion/core/tests/window_fuzz.rs
b/datafusion/core/tests/window_fuzz.rs
new file mode 100644
index 000000000..471484af2
--- /dev/null
+++ b/datafusion/core/tests/window_fuzz.rs
@@ -0,0 +1,385 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, Int32Array};
+use arrow::compute::{concat_batches, SortOptions};
+use arrow::record_batch::RecordBatch;
+use arrow::util::pretty::pretty_format_batches;
+use hashbrown::HashMap;
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
+
+use datafusion::physical_plan::collect;
+use datafusion::physical_plan::memory::MemoryExec;
+use datafusion::physical_plan::windows::{
+ create_window_expr, BoundedWindowAggExec, WindowAggExec,
+};
+use datafusion_expr::{
+ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound,
+ WindowFrameUnits, WindowFunction,
+};
+
+use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_common::ScalarValue;
+use datafusion_physical_expr::expressions::{col, lit};
+use datafusion_physical_expr::PhysicalSortExpr;
+use test_utils::add_empty_batches;
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
+ async fn single_order_by_test() {
+ let n = 100;
+ let distincts = vec![1, 100];
+ for distinct in distincts {
+ let mut handles = Vec::new();
+ for i in 1..n {
+ let job = tokio::spawn(run_window_test(
+ make_staggered_batches::<true>(1000, distinct, i),
+ i,
+ vec!["a"],
+ vec![],
+ ));
+ handles.push(job);
+ }
+ for job in handles {
+ job.await.unwrap();
+ }
+ }
+ }
+
+ #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
+ async fn order_by_with_partition_test() {
+ let n = 100;
+ let distincts = vec![1, 100];
+ for distinct in distincts {
+ // since we have sorted pairs (a, b), to not violate per-partition sorting
+ // partition should be field a, order by should be field b
+ let mut handles = Vec::new();
+ for i in 1..n {
+ let job = tokio::spawn(run_window_test(
+ make_staggered_batches::<true>(1000, distinct, i),
+ i,
+ vec!["b"],
+ vec!["a"],
+ ));
+ handles.push(job);
+ }
+ for job in handles {
+ job.await.unwrap();
+ }
+ }
+ }
+}
+
+/// Perform batch and running window computations on the same input,
+/// and verify outputs of `WindowAggExec` and `BoundedWindowAggExec` are equal
+async fn run_window_test(
+ input1: Vec<RecordBatch>,
+ random_seed: u64,
+ orderby_columns: Vec<&str>,
+ partition_by_columns: Vec<&str>,
+) {
+ let mut rng = StdRng::seed_from_u64(random_seed);
+ let schema = input1[0].schema();
+ let mut args = vec![col("x", &schema).unwrap()];
+ let mut window_fn_map = HashMap::new();
+ // Each HashMap value is a tuple: the first element is the WindowFunction,
+ // the second is any additional argument(s) the window function requires.
+ // For most window functions the additional argument list is empty.
+ window_fn_map.insert(
+ "sum",
+ (
+ WindowFunction::AggregateFunction(AggregateFunction::Sum),
+ vec![],
+ ),
+ );
+ window_fn_map.insert(
+ "count",
+ (
+ WindowFunction::AggregateFunction(AggregateFunction::Count),
+ vec![],
+ ),
+ );
+ window_fn_map.insert(
+ "min",
+ (
+ WindowFunction::AggregateFunction(AggregateFunction::Min),
+ vec![],
+ ),
+ );
+ window_fn_map.insert(
+ "max",
+ (
+ WindowFunction::AggregateFunction(AggregateFunction::Max),
+ vec![],
+ ),
+ );
+ window_fn_map.insert(
+ "row_number",
+ (
+
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber),
+ vec![],
+ ),
+ );
+ window_fn_map.insert(
+ "rank",
+ (
+ WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Rank),
+ vec![],
+ ),
+ );
+ window_fn_map.insert(
+ "first_value",
+ (
+
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue),
+ vec![],
+ ),
+ );
+ window_fn_map.insert(
+ "last_value",
+ (
+
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue),
+ vec![],
+ ),
+ );
+ window_fn_map.insert(
+ "nth_value",
+ (
+
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::NthValue),
+ vec![lit(ScalarValue::Int64(Some(rng.gen_range(1..10))))],
+ ),
+ );
+ window_fn_map.insert(
+ "lead",
+ (
+ WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
+ vec![
+ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
+ lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))),
+ ],
+ ),
+ );
+ window_fn_map.insert(
+ "lag",
+ (
+ WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag),
+ vec![
+ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
+ lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))),
+ ],
+ ),
+ );
+
+ let session_config = SessionConfig::new().with_batch_size(50);
+ let ctx = SessionContext::with_config(session_config);
+ let rand_fn_idx = rng.gen_range(0..window_fn_map.len());
+ let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];
+ let (window_fn, new_args) =
window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
+ for new_arg in new_args {
+ args.push(new_arg.clone());
+ }
+ let preceding = rng.gen_range(0..50);
+ let following = rng.gen_range(0..50);
+ let rand_num = rng.gen_range(0..3);
+ let units = if rand_num < 1 {
+ WindowFrameUnits::Range
+ } else if rand_num < 2 {
+ WindowFrameUnits::Rows
+ } else {
+ // For now we do not support GROUPS in BoundedWindowAggExec
implementation
+ // TODO: once GROUPS handling is available, use
WindowFrameUnits::GROUPS in randomized tests also.
+ WindowFrameUnits::Range
+ };
+ let window_frame = match units {
+ // In range queries window frame boundaries should match column type
+ WindowFrameUnits::Range => WindowFrame {
+ units,
+ start_bound:
WindowFrameBound::Preceding(ScalarValue::Int32(Some(preceding))),
+ end_bound:
WindowFrameBound::Following(ScalarValue::Int32(Some(following))),
+ },
+ // In ROWS queries, window frame boundaries should be UInt64
+ WindowFrameUnits::Rows => WindowFrame {
+ units,
+ start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(
+ preceding as u64,
+ ))),
+ end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(
+ following as u64,
+ ))),
+ },
+ // Once GROUPS support is added construct window frame for this case
also
+ _ => todo!(),
+ };
+ let mut orderby_exprs = vec![];
+ for column in orderby_columns {
+ orderby_exprs.push(PhysicalSortExpr {
+ expr: col(column, &schema).unwrap(),
+ options: SortOptions::default(),
+ })
+ }
+ let mut partitionby_exprs = vec![];
+ for column in partition_by_columns {
+ partitionby_exprs.push(col(column, &schema).unwrap());
+ }
+ let mut sort_keys = vec![];
+ for partition_by_expr in &partitionby_exprs {
+ sort_keys.push(PhysicalSortExpr {
+ expr: partition_by_expr.clone(),
+ options: SortOptions::default(),
+ })
+ }
+ for order_by_expr in &orderby_exprs {
+ sort_keys.push(order_by_expr.clone())
+ }
+
+ let concat_input_record = concat_batches(&schema, &input1).unwrap();
+ let exec1 = Arc::new(
+ MemoryExec::try_new(&[vec![concat_input_record]], schema.clone(),
None).unwrap(),
+ );
+ let usual_window_exec = Arc::new(
+ WindowAggExec::try_new(
+ vec![create_window_expr(
+ window_fn,
+ fn_name.to_string(),
+ &args,
+ &partitionby_exprs,
+ &orderby_exprs,
+ Arc::new(window_frame.clone()),
+ schema.as_ref(),
+ )
+ .unwrap()],
+ exec1,
+ schema.clone(),
+ vec![],
+ Some(sort_keys.clone()),
+ )
+ .unwrap(),
+ );
+ let exec2 =
+ Arc::new(MemoryExec::try_new(&[input1.clone()], schema.clone(),
None).unwrap());
+ let running_window_exec = Arc::new(
+ BoundedWindowAggExec::try_new(
+ vec![create_window_expr(
+ window_fn,
+ fn_name.to_string(),
+ &args,
+ &partitionby_exprs,
+ &orderby_exprs,
+ Arc::new(window_frame.clone()),
+ schema.as_ref(),
+ )
+ .unwrap()],
+ exec2,
+ schema.clone(),
+ vec![],
+ Some(sort_keys),
+ )
+ .unwrap(),
+ );
+
+ let task_ctx = ctx.task_ctx();
+ let collected_usual = collect(usual_window_exec,
task_ctx.clone()).await.unwrap();
+
+ let collected_running = collect(running_window_exec, task_ctx.clone())
+ .await
+ .unwrap();
+ // compare
+ let usual_formatted =
pretty_format_batches(&collected_usual).unwrap().to_string();
+ let running_formatted = pretty_format_batches(&collected_running)
+ .unwrap()
+ .to_string();
+
+ let mut usual_formatted_sorted: Vec<&str> =
usual_formatted.trim().lines().collect();
+ usual_formatted_sorted.sort_unstable();
+
+ let mut running_formatted_sorted: Vec<&str> =
+ running_formatted.trim().lines().collect();
+ running_formatted_sorted.sort_unstable();
+ for (i, (usual_line, running_line)) in usual_formatted_sorted
+ .iter()
+ .zip(&running_formatted_sorted)
+ .enumerate()
+ {
+ assert_eq!(
+ (i, usual_line),
+ (i, running_line),
+ "Inconsistent result for window_fn: {:?}, args:{:?}",
+ window_fn,
+ args
+ );
+ }
+}
+
+/// Return randomly sized record batches with:
+/// two sorted int32 columns 'a', 'b' with values ranging over 0..(len / distinct), and
+/// two random int32 columns 'x', 'y' as other columns
+fn make_staggered_batches<const STREAM: bool>(
+ len: usize,
+ distinct: usize,
+ random_seed: u64,
+) -> Vec<RecordBatch> {
+ // use a random number generator to pick a random sized output
+ let mut rng = StdRng::seed_from_u64(random_seed);
+ let mut input12: Vec<(i32, i32)> = vec![(0, 0); len];
+ let mut input3: Vec<i32> = vec![0; len];
+ let mut input4: Vec<i32> = vec![0; len];
+ input12.iter_mut().for_each(|v| {
+ *v = (
+ rng.gen_range(0..(len / distinct)) as i32,
+ rng.gen_range(0..(len / distinct)) as i32,
+ )
+ });
+ rng.fill(&mut input3[..]);
+ rng.fill(&mut input4[..]);
+ input12.sort();
+ let input1 =
Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.0));
+ let input2 =
Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.1));
+ let input3 = Int32Array::from_iter_values(input3.into_iter());
+ let input4 = Int32Array::from_iter_values(input4.into_iter());
+
+ // split into several record batches
+ let mut remainder = RecordBatch::try_from_iter(vec![
+ ("a", Arc::new(input1) as ArrayRef),
+ ("b", Arc::new(input2) as ArrayRef),
+ ("x", Arc::new(input3) as ArrayRef),
+ ("y", Arc::new(input4) as ArrayRef),
+ ])
+ .unwrap();
+
+ let mut batches = vec![];
+ if STREAM {
+ while remainder.num_rows() > 0 {
+ let batch_size = rng.gen_range(0..50);
+ if remainder.num_rows() < batch_size {
+ break;
+ }
+ batches.push(remainder.slice(0, batch_size));
+ remainder = remainder.slice(batch_size, remainder.num_rows() -
batch_size);
+ }
+ } else {
+ while remainder.num_rows() > 0 {
+ let batch_size = rng.gen_range(0..remainder.num_rows() + 1);
+ batches.push(remainder.slice(0, batch_size));
+ remainder = remainder.slice(batch_size, remainder.num_rows() -
batch_size);
+ }
+ }
+ add_empty_batches(batches, &mut rng)
+}
diff --git a/datafusion/physical-expr/Cargo.toml
b/datafusion/physical-expr/Cargo.toml
index 094d233a9..5aede03fd 100644
--- a/datafusion/physical-expr/Cargo.toml
+++ b/datafusion/physical-expr/Cargo.toml
@@ -51,6 +51,7 @@ datafusion-expr = { path = "../expr", version = "15.0.0" }
datafusion-row = { path = "../row", version = "15.0.0" }
half = { version = "2.1", default-features = false }
hashbrown = { version = "0.13", features = ["raw"] }
+indexmap = "1.9.2"
itertools = { version = "0.10", features = ["use_std"] }
lazy_static = { version = "^1.4.0" }
md-5 = { version = "^0.10.0", optional = true }
diff --git a/datafusion/physical-expr/src/aggregate/count.rs
b/datafusion/physical-expr/src/aggregate/count.rs
index 813952117..8ccf87ac2 100644
--- a/datafusion/physical-expr/src/aggregate/count.rs
+++ b/datafusion/physical-expr/src/aggregate/count.rs
@@ -98,6 +98,10 @@ impl AggregateExpr for Count {
true
}
+ fn supports_bounded_execution(&self) -> bool {
+ true
+ }
+
fn create_row_accumulator(
&self,
start_index: usize,
diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs
b/datafusion/physical-expr/src/aggregate/min_max.rs
index a7bd6c360..bf4fd0868 100644
--- a/datafusion/physical-expr/src/aggregate/min_max.rs
+++ b/datafusion/physical-expr/src/aggregate/min_max.rs
@@ -62,7 +62,7 @@ fn min_max_aggregate_data_type(input_type: DataType) ->
DataType {
}
/// MAX aggregate expression
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Max {
name: String,
data_type: DataType,
@@ -124,6 +124,10 @@ impl AggregateExpr for Max {
is_row_accumulator_support_dtype(&self.data_type)
}
+ fn supports_bounded_execution(&self) -> bool {
+ true
+ }
+
fn create_row_accumulator(
&self,
start_index: usize,
@@ -134,6 +138,10 @@ impl AggregateExpr for Max {
)))
}
+ fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
+ Some(Arc::new(self.clone()))
+ }
+
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(SlidingMaxAccumulator::try_new(&self.data_type)?))
}
@@ -672,7 +680,7 @@ impl RowAccumulator for MaxRowAccumulator {
}
/// MIN aggregate expression
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Min {
name: String,
data_type: DataType,
@@ -734,6 +742,10 @@ impl AggregateExpr for Min {
is_row_accumulator_support_dtype(&self.data_type)
}
+ fn supports_bounded_execution(&self) -> bool {
+ true
+ }
+
fn create_row_accumulator(
&self,
start_index: usize,
@@ -744,6 +756,10 @@ impl AggregateExpr for Min {
)))
}
+ fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
+ Some(Arc::new(self.clone()))
+ }
+
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(SlidingMinAccumulator::try_new(&self.data_type)?))
}
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs
b/datafusion/physical-expr/src/aggregate/mod.rs
index 528a5cc73..c42a5c03b 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -88,6 +88,12 @@ pub trait AggregateExpr: Send + Sync + Debug {
false
}
+ /// Specifies whether this aggregate function can run using bounded memory.
+ /// Any accumulator returning "true" needs to implement `retract_batch`.
+ fn supports_bounded_execution(&self) -> bool {
+ false
+ }
+
/// RowAccumulator to access/update row-based aggregation state in-place.
/// Currently, row accumulator only supports states of fixed-sized type.
///
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs
b/datafusion/physical-expr/src/aggregate/sum.rs
index 5de9a9296..8f78abfd5 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -113,6 +113,10 @@ impl AggregateExpr for Sum {
is_row_accumulator_support_dtype(&self.data_type)
}
+ fn supports_bounded_execution(&self) -> bool {
+ true
+ }
+
fn create_row_accumulator(
&self,
start_index: usize,
diff --git a/datafusion/physical-expr/src/window/aggregate.rs
b/datafusion/physical-expr/src/window/aggregate.rs
index 5c46f38f2..df61e7cc8 100644
--- a/datafusion/physical-expr/src/window/aggregate.rs
+++ b/datafusion/physical-expr/src/window/aggregate.rs
@@ -29,7 +29,7 @@ use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
-use datafusion_expr::WindowFrame;
+use datafusion_expr::{WindowFrame, WindowFrameUnits};
use crate::window::window_expr::reverse_order_bys;
use crate::window::SlidingAggregateWindowExpr;
@@ -162,4 +162,12 @@ impl WindowExpr for AggregateWindowExpr {
}
})
}
+
+ fn uses_bounded_memory(&self) -> bool {
+ // NOTE: Currently, groups queries do not support the bounded memory
variant.
+ self.aggregate.supports_bounded_execution()
+ && !self.window_frame.start_bound.is_unbounded()
+ && !self.window_frame.end_bound.is_unbounded()
+ && !matches!(self.window_frame.units, WindowFrameUnits::Groups)
+ }
}
diff --git a/datafusion/physical-expr/src/window/built_in.rs
b/datafusion/physical-expr/src/window/built_in.rs
index 9804432b2..f0484b790 100644
--- a/datafusion/physical-expr/src/window/built_in.rs
+++ b/datafusion/physical-expr/src/window/built_in.rs
@@ -20,14 +20,19 @@
use super::window_frame_state::WindowFrameContext;
use super::BuiltInWindowFunctionExpr;
use super::WindowExpr;
-use crate::window::window_expr::reverse_order_bys;
+use crate::window::window_expr::{
+ reverse_order_bys, BuiltinWindowState, WindowFn, WindowFunctionState,
+};
+use crate::window::{
+ PartitionBatches, PartitionWindowAggStates, WindowAggState, WindowState,
+};
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
-use arrow::compute::SortOptions;
+use arrow::compute::{concat, SortOptions};
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
-use datafusion_common::Result;
use datafusion_common::ScalarValue;
-use datafusion_expr::WindowFrame;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::{WindowFrame, WindowFrameUnits};
use std::any::Any;
use std::sync::Arc;
@@ -91,7 +96,7 @@ impl WindowExpr for BuiltInWindowExpr {
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
let evaluator = self.expr.create_evaluator()?;
let num_rows = batch.num_rows();
- if evaluator.uses_window_frame() {
+ if self.expr.uses_window_frame() {
let sort_options: Vec<SortOptions> =
self.order_by.iter().map(|o| o.options).collect();
let mut row_wise_results = vec![];
@@ -122,6 +127,102 @@ impl WindowExpr for BuiltInWindowExpr {
}
}
+ /// Evaluate the window function against the batch. This function
facilitates
+ /// stateful, bounded-memory implementations.
+ fn evaluate_stateful(
+ &self,
+ partition_batches: &PartitionBatches,
+ window_agg_state: &mut PartitionWindowAggStates,
+ ) -> Result<()> {
+ let field = self.expr.field()?;
+ let out_type = field.data_type();
+ let sort_options = self.order_by.iter().map(|o|
o.options).collect::<Vec<_>>();
+ for (partition_row, partition_batch_state) in partition_batches.iter()
{
+ if !window_agg_state.contains_key(partition_row) {
+ let evaluator = self.expr.create_evaluator()?;
+ window_agg_state.insert(
+ partition_row.clone(),
+ WindowState {
+ state: WindowAggState::new(
+ out_type,
+ WindowFunctionState::BuiltinWindowState(
+ BuiltinWindowState::Default,
+ ),
+ )?,
+ window_fn: WindowFn::Builtin(evaluator),
+ },
+ );
+ };
+ let window_state =
+ window_agg_state.get_mut(partition_row).ok_or_else(|| {
+ DataFusionError::Execution("Cannot find state".to_string())
+ })?;
+ let evaluator = match &mut window_state.window_fn {
+ WindowFn::Builtin(evaluator) => evaluator,
+ _ => unreachable!(),
+ };
+ let mut state = &mut window_state.state;
+ state.is_end = partition_batch_state.is_end;
+
+ let (values, order_bys) =
+ self.get_values_orderbys(&partition_batch_state.record_batch)?;
+
+ // We iterate on each row to perform a running calculation.
+ let num_rows = partition_batch_state.record_batch.num_rows();
+ let mut last_range = state.window_frame_range.clone();
+ let mut window_frame_ctx =
WindowFrameContext::new(&self.window_frame);
+ let sort_partition_points = if evaluator.include_rank() {
+ let columns =
self.sort_columns(&partition_batch_state.record_batch)?;
+ self.evaluate_partition_points(num_rows, &columns)?
+ } else {
+ vec![]
+ };
+ let mut row_wise_results: Vec<ScalarValue> = vec![];
+ for idx in state.last_calculated_index..num_rows {
+ state.window_frame_range = if self.expr.uses_window_frame() {
+ window_frame_ctx.calculate_range(
+ &order_bys,
+ &sort_options,
+ num_rows,
+ idx,
+ )
+ } else {
+ evaluator.get_range(state, num_rows)
+ }?;
+ evaluator.update_state(state, &order_bys,
&sort_partition_points)?;
+
+ // Stop if the frame reaches the end of the batch while the partition is
+ // not yet complete; we need more input data to evaluate this row.
+ if state.window_frame_range.end == num_rows
+ && !partition_batch_state.is_end
+ {
+ state.window_frame_range = last_range.clone();
+ break;
+ }
+ let frame_range = &state.window_frame_range;
+ row_wise_results.push(if frame_range.start == frame_range.end {
+ // We produce None if the window is empty.
+ ScalarValue::try_from(out_type)
+ } else {
+ evaluator.evaluate_stateful(&values)
+ }?);
+ last_range = frame_range.clone();
+ state.last_calculated_index = idx + 1;
+ }
+ state.window_frame_range = last_range;
+ let out_col = if row_wise_results.is_empty() {
+ ScalarValue::try_from(out_type)?.to_array_of_size(0)
+ } else {
+ ScalarValue::iter_to_array(row_wise_results.into_iter())?
+ };
+
+ state.out_col = concat(&[&state.out_col, &out_col])?;
+ state.n_row_result_missing = num_rows -
state.last_calculated_index;
+ state.window_function_state =
+ WindowFunctionState::BuiltinWindowState(evaluator.state()?);
+ }
+ Ok(())
+ }
+
fn get_window_frame(&self) -> &Arc<WindowFrame> {
&self.window_frame
}
@@ -136,4 +237,13 @@ impl WindowExpr for BuiltInWindowExpr {
)) as _
})
}
+
+ fn uses_bounded_memory(&self) -> bool {
+ // NOTE: Currently, groups queries do not support the bounded memory
variant.
+ self.expr.supports_bounded_execution()
+ && (!self.expr.uses_window_frame()
+ || !(self.window_frame.start_bound.is_unbounded()
+ || self.window_frame.end_bound.is_unbounded()
+ || matches!(self.window_frame.units,
WindowFrameUnits::Groups)))
+ }
}
diff --git
a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
index c358403fe..6f41ec599 100644
--- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
+++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
@@ -64,4 +64,12 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync +
std::fmt::Debug {
fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
None
}
+
+ fn supports_bounded_execution(&self) -> bool {
+ false
+ }
+
+ fn uses_window_frame(&self) -> bool {
+ false
+ }
}
diff --git a/datafusion/physical-expr/src/window/cume_dist.rs
b/datafusion/physical-expr/src/window/cume_dist.rs
index 45fe51178..3abb91e06 100644
--- a/datafusion/physical-expr/src/window/cume_dist.rs
+++ b/datafusion/physical-expr/src/window/cume_dist.rs
@@ -66,6 +66,7 @@ impl BuiltInWindowFunctionExpr for CumeDist {
}
}
+#[derive(Debug)]
pub(crate) struct CumeDistEvaluator;
impl PartitionEvaluator for CumeDistEvaluator {
diff --git a/datafusion/physical-expr/src/window/lead_lag.rs
b/datafusion/physical-expr/src/window/lead_lag.rs
index e18815c4c..fc815a220 100644
--- a/datafusion/physical-expr/src/window/lead_lag.rs
+++ b/datafusion/physical-expr/src/window/lead_lag.rs
@@ -19,7 +19,8 @@
//! at runtime during query execution
use crate::window::partition_evaluator::PartitionEvaluator;
-use crate::window::BuiltInWindowFunctionExpr;
+use crate::window::window_expr::{BuiltinWindowState, LeadLagState};
+use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
use crate::PhysicalExpr;
use arrow::array::ArrayRef;
use arrow::compute::cast;
@@ -27,7 +28,8 @@ use arrow::datatypes::{DataType, Field};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use std::any::Any;
-use std::ops::Neg;
+use std::cmp::min;
+use std::ops::{Neg, Range};
use std::sync::Arc;
/// window shift expression
@@ -102,11 +104,16 @@ impl BuiltInWindowFunctionExpr for WindowShift {
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(WindowShiftEvaluator {
+ state: LeadLagState { idx: 0 },
shift_offset: self.shift_offset,
default_value: self.default_value.clone(),
}))
}
+ fn supports_bounded_execution(&self) -> bool {
+ true
+ }
+
fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
Some(Arc::new(Self {
name: self.name.clone(),
@@ -118,7 +125,9 @@ impl BuiltInWindowFunctionExpr for WindowShift {
}
}
+#[derive(Debug)]
pub(crate) struct WindowShiftEvaluator {
+ state: LeadLagState,
shift_offset: i64,
default_value: Option<ScalarValue>,
}
@@ -173,6 +182,54 @@ fn shift_with_default_value(
}
impl PartitionEvaluator for WindowShiftEvaluator {
+ fn state(&self) -> Result<BuiltinWindowState> {
+        // Expose the LEAD/LAG evaluator state so execution can resume from it later.
+ Ok(BuiltinWindowState::LeadLag(self.state.clone()))
+ }
+
+ fn update_state(
+ &mut self,
+ state: &WindowAggState,
+ _range_columns: &[ArrayRef],
+ _sort_partition_points: &[Range<usize>],
+ ) -> Result<()> {
+ self.state.idx = state.last_calculated_index;
+ Ok(())
+ }
+
+ fn get_range(&self, state: &WindowAggState, n_rows: usize) ->
Result<Range<usize>> {
+ if self.shift_offset > 0 {
+ let offset = self.shift_offset as usize;
+ let start = if state.last_calculated_index > offset {
+ state.last_calculated_index - offset
+ } else {
+ 0
+ };
+ Ok(Range {
+ start,
+ end: state.last_calculated_index + 1,
+ })
+ } else {
+ let offset = (-self.shift_offset) as usize;
+ let end = min(state.last_calculated_index + offset, n_rows);
+ Ok(Range {
+ start: state.last_calculated_index,
+ end,
+ })
+ }
+ }
+
+ fn evaluate_stateful(&mut self, values: &[ArrayRef]) ->
Result<ScalarValue> {
+ let array = &values[0];
+ let dtype = array.data_type();
+ let idx = self.state.idx as i64 - self.shift_offset;
+ if idx < 0 || idx as usize >= array.len() {
+ get_default_value(&self.default_value, dtype)
+ } else {
+ ScalarValue::try_from_array(array, idx as usize)
+ }
+ }
+
fn evaluate(&self, values: &[ArrayRef], _num_rows: usize) ->
Result<ArrayRef> {
// LEAD, LAG window functions take single column, values will have
size 1
let value = &values[0];
@@ -180,6 +237,23 @@ impl PartitionEvaluator for WindowShiftEvaluator {
}
}
+fn get_default_value(
+ default_value: &Option<ScalarValue>,
+ dtype: &DataType,
+) -> Result<ScalarValue> {
+ if let Some(value) = default_value {
+ if let ScalarValue::Int64(Some(val)) = value {
+ ScalarValue::try_from_string(val.to_string(), dtype)
+ } else {
+ Err(DataFusionError::Internal(
+ "Expects default value to have Int64 type".to_string(),
+ ))
+ }
+ } else {
+ Ok(ScalarValue::try_from(dtype)?)
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/physical-expr/src/window/mod.rs
b/datafusion/physical-expr/src/window/mod.rs
index ffbce598e..35036a6db 100644
--- a/datafusion/physical-expr/src/window/mod.rs
+++ b/datafusion/physical-expr/src/window/mod.rs
@@ -33,4 +33,10 @@ pub use aggregate::AggregateWindowExpr;
pub use built_in::BuiltInWindowExpr;
pub use built_in_window_function_expr::BuiltInWindowFunctionExpr;
pub use sliding_aggregate::SlidingAggregateWindowExpr;
+pub use window_expr::PartitionBatchState;
+pub use window_expr::PartitionBatches;
+pub use window_expr::PartitionKey;
+pub use window_expr::PartitionWindowAggStates;
+pub use window_expr::WindowAggState;
pub use window_expr::WindowExpr;
+pub use window_expr::WindowState;
diff --git a/datafusion/physical-expr/src/window/nth_value.rs
b/datafusion/physical-expr/src/window/nth_value.rs
index e998b4701..c3c3b55d4 100644
--- a/datafusion/physical-expr/src/window/nth_value.rs
+++ b/datafusion/physical-expr/src/window/nth_value.rs
@@ -19,7 +19,8 @@
//! that can evaluated at runtime during query execution
use crate::window::partition_evaluator::PartitionEvaluator;
-use crate::window::BuiltInWindowFunctionExpr;
+use crate::window::window_expr::{BuiltinWindowState, NthValueState};
+use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
use crate::PhysicalExpr;
use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{DataType, Field};
@@ -121,7 +122,18 @@ impl BuiltInWindowFunctionExpr for NthValue {
}
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
- Ok(Box::new(NthValueEvaluator { kind: self.kind }))
+ Ok(Box::new(NthValueEvaluator {
+ state: NthValueState::default(),
+ kind: self.kind,
+ }))
+ }
+
+ fn supports_bounded_execution(&self) -> bool {
+ true
+ }
+
+ fn uses_window_frame(&self) -> bool {
+ true
}
fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
@@ -140,13 +152,31 @@ impl BuiltInWindowFunctionExpr for NthValue {
}
/// Value evaluator for nth_value functions
+#[derive(Debug)]
pub(crate) struct NthValueEvaluator {
+ state: NthValueState,
kind: NthValueKind,
}
impl PartitionEvaluator for NthValueEvaluator {
- fn uses_window_frame(&self) -> bool {
- true
+ fn state(&self) -> Result<BuiltinWindowState> {
+        // Expose the NTH_VALUE evaluator state (the last evaluated frame range).
+ Ok(BuiltinWindowState::NthValue(self.state.clone()))
+ }
+
+ fn update_state(
+ &mut self,
+ state: &WindowAggState,
+ _range_columns: &[ArrayRef],
+ _sort_partition_points: &[Range<usize>],
+ ) -> Result<()> {
+        // Remember the current frame range so evaluate_stateful can use it.
+ self.state.range = state.window_frame_range.clone();
+ Ok(())
+ }
+
+ fn evaluate_stateful(&mut self, values: &[ArrayRef]) ->
Result<ScalarValue> {
+ self.evaluate_inside_range(values, self.state.range.clone())
}
fn evaluate_inside_range(
diff --git a/datafusion/physical-expr/src/window/ntile.rs
b/datafusion/physical-expr/src/window/ntile.rs
index f5844eccc..b8365dba1 100644
--- a/datafusion/physical-expr/src/window/ntile.rs
+++ b/datafusion/physical-expr/src/window/ntile.rs
@@ -64,6 +64,7 @@ impl BuiltInWindowFunctionExpr for Ntile {
}
}
+#[derive(Debug)]
pub(crate) struct NtileEvaluator {
n: u64,
}
diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs
b/datafusion/physical-expr/src/window/partition_evaluator.rs
index 86500441d..e6cead76d 100644
--- a/datafusion/physical-expr/src/window/partition_evaluator.rs
+++ b/datafusion/physical-expr/src/window/partition_evaluator.rs
@@ -17,26 +17,54 @@
//! partition evaluation module
+use crate::window::window_expr::BuiltinWindowState;
+use crate::window::WindowAggState;
use arrow::array::ArrayRef;
use datafusion_common::Result;
use datafusion_common::{DataFusionError, ScalarValue};
+use std::fmt::Debug;
use std::ops::Range;
/// Partition evaluator
-pub trait PartitionEvaluator {
+pub trait PartitionEvaluator: Debug + Send {
/// Whether the evaluator should be evaluated with rank
fn include_rank(&self) -> bool {
false
}
- fn uses_window_frame(&self) -> bool {
- false
+ /// Returns state of the Built-in Window Function
+ fn state(&self) -> Result<BuiltinWindowState> {
+ // If we do not use state we just return Default
+ Ok(BuiltinWindowState::Default)
+ }
+
+ fn update_state(
+ &mut self,
+ _state: &WindowAggState,
+ _range_columns: &[ArrayRef],
+ _sort_partition_points: &[Range<usize>],
+ ) -> Result<()> {
+ // If we do not use state, update_state does nothing
+ Ok(())
+ }
+
+ fn get_range(&self, _state: &WindowAggState, _n_rows: usize) ->
Result<Range<usize>> {
+ Err(DataFusionError::NotImplemented(
+ "get_range is not implemented for this window
function".to_string(),
+ ))
}
/// evaluate the partition evaluator against the partition
fn evaluate(&self, _values: &[ArrayRef], _num_rows: usize) ->
Result<ArrayRef> {
Err(DataFusionError::NotImplemented(
- "evaluate_partition is not implemented by default".into(),
+ "evaluate is not implemented by default".into(),
+ ))
+ }
+
+    /// Evaluate the window function for the current row using internally maintained state
+ fn evaluate_stateful(&mut self, _values: &[ArrayRef]) ->
Result<ScalarValue> {
+ Err(DataFusionError::NotImplemented(
+ "evaluate_stateful is not implemented by default".into(),
))
}
diff --git a/datafusion/physical-expr/src/window/rank.rs
b/datafusion/physical-expr/src/window/rank.rs
index 87e01528d..ead9d4453 100644
--- a/datafusion/physical-expr/src/window/rank.rs
+++ b/datafusion/physical-expr/src/window/rank.rs
@@ -19,12 +19,13 @@
//! at runtime during query execution
use crate::window::partition_evaluator::PartitionEvaluator;
-use crate::window::BuiltInWindowFunctionExpr;
+use crate::window::window_expr::{BuiltinWindowState, RankState};
+use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
use crate::PhysicalExpr;
use arrow::array::ArrayRef;
use arrow::array::{Float64Array, UInt64Array};
use arrow::datatypes::{DataType, Field};
-use datafusion_common::Result;
+use datafusion_common::{DataFusionError, Result, ScalarValue};
use std::any::Any;
use std::iter;
use std::ops::Range;
@@ -98,18 +99,77 @@ impl BuiltInWindowFunctionExpr for Rank {
&self.name
}
+ fn supports_bounded_execution(&self) -> bool {
+ matches!(self.rank_type, RankType::Basic | RankType::Dense)
+ }
+
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(RankEvaluator {
+ state: RankState::default(),
rank_type: self.rank_type,
}))
}
}
+#[derive(Debug)]
pub(crate) struct RankEvaluator {
+ state: RankState,
rank_type: RankType,
}
impl PartitionEvaluator for RankEvaluator {
+ fn get_range(&self, state: &WindowAggState, _n_rows: usize) ->
Result<Range<usize>> {
+ Ok(Range {
+ start: state.last_calculated_index,
+ end: state.last_calculated_index + 1,
+ })
+ }
+
+ fn state(&self) -> Result<BuiltinWindowState> {
+ Ok(BuiltinWindowState::Rank(self.state.clone()))
+ }
+
+ fn update_state(
+ &mut self,
+ state: &WindowAggState,
+ range_columns: &[ArrayRef],
+ sort_partition_points: &[Range<usize>],
+ ) -> Result<()> {
+ // find range inside `sort_partition_points` containing
`state.last_calculated_index`
+ let chunk_idx = sort_partition_points
+ .iter()
+ .position(|elem| {
+ elem.start <= state.last_calculated_index
+ && state.last_calculated_index < elem.end
+ })
+ .ok_or_else(|| DataFusionError::Execution("Expects
sort_partition_points to contain state.last_calculated_index".to_string()))?;
+ let chunk = &sort_partition_points[chunk_idx];
+ let last_rank_data = range_columns
+ .iter()
+ .map(|c| ScalarValue::try_from_array(c, chunk.end - 1))
+ .collect::<Result<Vec<_>>>()?;
+ let empty = self.state.last_rank_data.is_empty();
+ if empty || self.state.last_rank_data != last_rank_data {
+ self.state.last_rank_data = last_rank_data;
+ self.state.last_rank_boundary = state.offset_pruned_rows +
chunk.start;
+ self.state.n_rank = 1 + if empty { chunk_idx } else {
self.state.n_rank };
+ }
+ Ok(())
+ }
+
+    /// Compute the rank of the current row from the maintained RankState
+ fn evaluate_stateful(&mut self, _values: &[ArrayRef]) ->
Result<ScalarValue> {
+ match self.rank_type {
+ RankType::Basic => Ok(ScalarValue::UInt64(Some(
+ self.state.last_rank_boundary as u64 + 1,
+ ))),
+ RankType::Dense => Ok(ScalarValue::UInt64(Some(self.state.n_rank
as u64))),
+ RankType::Percent => Err(DataFusionError::Execution(
+ "Can not execute PERCENT_RANK in a streaming
fashion".to_string(),
+ )),
+ }
+ }
+
fn include_rank(&self) -> bool {
true
}
diff --git a/datafusion/physical-expr/src/window/row_number.rs
b/datafusion/physical-expr/src/window/row_number.rs
index b27ac29d2..c858a5724 100644
--- a/datafusion/physical-expr/src/window/row_number.rs
+++ b/datafusion/physical-expr/src/window/row_number.rs
@@ -18,12 +18,14 @@
//! Defines physical expression for `row_number` that can evaluated at runtime
during query execution
use crate::window::partition_evaluator::PartitionEvaluator;
-use crate::window::BuiltInWindowFunctionExpr;
+use crate::window::window_expr::{BuiltinWindowState, NumRowsState};
+use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
use crate::PhysicalExpr;
use arrow::array::{ArrayRef, UInt64Array};
use arrow::datatypes::{DataType, Field};
-use datafusion_common::Result;
+use datafusion_common::{Result, ScalarValue};
use std::any::Any;
+use std::ops::Range;
use std::sync::Arc;
/// row_number expression
@@ -62,12 +64,36 @@ impl BuiltInWindowFunctionExpr for RowNumber {
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::<NumRowsEvaluator>::default())
}
+
+ fn supports_bounded_execution(&self) -> bool {
+ true
+ }
}
-#[derive(Default)]
-pub(crate) struct NumRowsEvaluator {}
+#[derive(Default, Debug)]
+pub(crate) struct NumRowsEvaluator {
+ state: NumRowsState,
+}
impl PartitionEvaluator for NumRowsEvaluator {
+ fn state(&self) -> Result<BuiltinWindowState> {
+        // Expose the running row count so execution can resume from it later.
+ Ok(BuiltinWindowState::NumRows(self.state.clone()))
+ }
+
+ fn get_range(&self, state: &WindowAggState, _n_rows: usize) ->
Result<Range<usize>> {
+ Ok(Range {
+ start: state.last_calculated_index,
+ end: state.last_calculated_index + 1,
+ })
+ }
+
+    /// Advance the running row count and return it as the current row number
+ fn evaluate_stateful(&mut self, _values: &[ArrayRef]) ->
Result<ScalarValue> {
+ self.state.n_rows += 1;
+ Ok(ScalarValue::UInt64(Some(self.state.n_rows as u64)))
+ }
+
fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) ->
Result<ArrayRef> {
Ok(Arc::new(UInt64Array::from_iter_values(
1..(num_rows as u64) + 1,
diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs
b/datafusion/physical-expr/src/window/sliding_aggregate.rs
index 2a0fa86b7..587c313e3 100644
--- a/datafusion/physical-expr/src/window/sliding_aggregate.rs
+++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs
@@ -23,16 +23,19 @@ use std::ops::Range;
use std::sync::Arc;
use arrow::array::Array;
-use arrow::compute::SortOptions;
+use arrow::compute::{concat, SortOptions};
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
-use datafusion_common::Result;
use datafusion_common::ScalarValue;
-use datafusion_expr::WindowFrame;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::{Accumulator, WindowFrame, WindowFrameUnits};
-use crate::window::window_expr::reverse_order_bys;
-use crate::window::AggregateWindowExpr;
+use crate::window::window_expr::{reverse_order_bys, WindowFn,
WindowFunctionState};
+use crate::window::{
+ AggregateWindowExpr, PartitionBatches, PartitionWindowAggStates,
WindowAggState,
+ WindowState,
+};
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
use crate::{window::WindowExpr, AggregateExpr};
@@ -92,50 +95,75 @@ impl WindowExpr for SlidingAggregateWindowExpr {
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
- let sort_options: Vec<SortOptions> =
- self.order_by.iter().map(|o| o.options).collect();
- let mut row_wise_results: Vec<ScalarValue> = vec![];
-
let mut accumulator = self.aggregate.create_sliding_accumulator()?;
- let length = batch.num_rows();
- let (values, order_bys) = self.get_values_orderbys(batch)?;
let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
let mut last_range = Range { start: 0, end: 0 };
+ let mut idx = 0;
+ self.get_result_column(
+ &mut accumulator,
+ batch,
+ &mut window_frame_ctx,
+ &mut last_range,
+ &mut idx,
+ true,
+ )
+ }
- // We iterate on each row to perform a running calculation.
- // First, cur_range is calculated, then it is compared with last_range.
- for i in 0..length {
- let cur_range =
- window_frame_ctx.calculate_range(&order_bys, &sort_options,
length, i)?;
- let value = if cur_range.start == cur_range.end {
- // We produce None if the window is empty.
- ScalarValue::try_from(self.aggregate.field()?.data_type())?
- } else {
- // Accumulate any new rows that have entered the window:
- let update_bound = cur_range.end - last_range.end;
- if update_bound > 0 {
- let update: Vec<ArrayRef> = values
- .iter()
- .map(|v| v.slice(last_range.end, update_bound))
- .collect();
- accumulator.update_batch(&update)?
- }
- // Remove rows that have now left the window:
- let retract_bound = cur_range.start - last_range.start;
- if retract_bound > 0 {
- let retract: Vec<ArrayRef> = values
- .iter()
- .map(|v| v.slice(last_range.start, retract_bound))
- .collect();
- accumulator.retract_batch(&retract)?
- }
- accumulator.evaluate()?
+ fn evaluate_stateful(
+ &self,
+ partition_batches: &PartitionBatches,
+ window_agg_state: &mut PartitionWindowAggStates,
+ ) -> Result<()> {
+ let field = self.aggregate.field()?;
+ let out_type = field.data_type();
+ for (partition_row, partition_batch_state) in partition_batches.iter()
{
+ if !window_agg_state.contains_key(partition_row) {
+ let accumulator = self.aggregate.create_sliding_accumulator()?;
+ window_agg_state.insert(
+ partition_row.clone(),
+ WindowState {
+ state: WindowAggState::new(
+ out_type,
+ WindowFunctionState::AggregateState(vec![]),
+ )?,
+ window_fn: WindowFn::Aggregate(accumulator),
+ },
+ );
};
- row_wise_results.push(value);
- last_range = cur_range;
+ let window_state =
+ window_agg_state.get_mut(partition_row).ok_or_else(|| {
+ DataFusionError::Execution("Cannot find state".to_string())
+ })?;
+ let accumulator = match &mut window_state.window_fn {
+ WindowFn::Aggregate(accumulator) => accumulator,
+ _ => unreachable!(),
+ };
+ let mut state = &mut window_state.state;
+ state.is_end = partition_batch_state.is_end;
+
+ let mut idx = state.last_calculated_index;
+ let mut last_range = state.window_frame_range.clone();
+ let mut window_frame_ctx =
WindowFrameContext::new(&self.window_frame);
+ let out_col = self.get_result_column(
+ accumulator,
+ &partition_batch_state.record_batch,
+ &mut window_frame_ctx,
+ &mut last_range,
+ &mut idx,
+ state.is_end,
+ )?;
+ state.last_calculated_index = idx;
+ state.window_frame_range = last_range.clone();
+
+ state.out_col = concat(&[&state.out_col, &out_col])?;
+ let num_rows = partition_batch_state.record_batch.num_rows();
+ state.n_row_result_missing = num_rows -
state.last_calculated_index;
+
+ state.window_function_state =
+ WindowFunctionState::AggregateState(accumulator.state()?);
}
- ScalarValue::iter_to_array(row_wise_results.into_iter())
+ Ok(())
}
fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
@@ -170,4 +198,96 @@ impl WindowExpr for SlidingAggregateWindowExpr {
}
})
}
+
+ fn uses_bounded_memory(&self) -> bool {
+ // NOTE: Currently, groups queries do not support the bounded memory
variant.
+ self.aggregate.supports_bounded_execution()
+ && !self.window_frame.start_bound.is_unbounded()
+ && !self.window_frame.end_bound.is_unbounded()
+ && !matches!(self.window_frame.units, WindowFrameUnits::Groups)
+ }
+}
+
+impl SlidingAggregateWindowExpr {
+ /// For given range calculate accumulator result inside range on
value_slice and
+ /// update accumulator state
+ fn get_aggregate_result_inside_range(
+ &self,
+ last_range: &Range<usize>,
+ cur_range: &Range<usize>,
+ value_slice: &[ArrayRef],
+ accumulator: &mut Box<dyn Accumulator>,
+ ) -> Result<ScalarValue> {
+ let value = if cur_range.start == cur_range.end {
+ // We produce None if the window is empty.
+ ScalarValue::try_from(self.aggregate.field()?.data_type())?
+ } else {
+ // Accumulate any new rows that have entered the window:
+ let update_bound = cur_range.end - last_range.end;
+ if update_bound > 0 {
+ let update: Vec<ArrayRef> = value_slice
+ .iter()
+ .map(|v| v.slice(last_range.end, update_bound))
+ .collect();
+ accumulator.update_batch(&update)?
+ }
+ // Remove rows that have now left the window:
+ let retract_bound = cur_range.start - last_range.start;
+ if retract_bound > 0 {
+ let retract: Vec<ArrayRef> = value_slice
+ .iter()
+ .map(|v| v.slice(last_range.start, retract_bound))
+ .collect();
+ accumulator.retract_batch(&retract)?
+ }
+ accumulator.evaluate()?
+ };
+ Ok(value)
+ }
+
+ fn get_result_column(
+ &self,
+ accumulator: &mut Box<dyn Accumulator>,
+ record_batch: &RecordBatch,
+ window_frame_ctx: &mut WindowFrameContext,
+ last_range: &mut Range<usize>,
+ idx: &mut usize,
+ is_end: bool,
+ ) -> Result<ArrayRef> {
+ let (values, order_bys) = self.get_values_orderbys(record_batch)?;
+ // We iterate on each row to perform a running calculation.
+ let length = values[0].len();
+ let sort_options: Vec<SortOptions> =
+ self.order_by.iter().map(|o| o.options).collect();
+ let mut row_wise_results: Vec<ScalarValue> = vec![];
+ let field = self.aggregate.field()?;
+ let out_type = field.data_type();
+ while *idx < length {
+ let cur_range = window_frame_ctx.calculate_range(
+ &order_bys,
+ &sort_options,
+ length,
+ *idx,
+ )?;
+            // Stop if the frame reaches the end of a still-open partition; we need more data before this result can be finalized.
+ if cur_range.end == length && !is_end {
+ break;
+ }
+ let value = self.get_aggregate_result_inside_range(
+ last_range,
+ &cur_range,
+ &values,
+ accumulator,
+ )?;
+ row_wise_results.push(value);
+ last_range.start = cur_range.start;
+ last_range.end = cur_range.end;
+ *idx += 1;
+ }
+ Ok(if row_wise_results.is_empty() {
+ ScalarValue::try_from(out_type)?.to_array_of_size(0)
+ } else {
+ ScalarValue::iter_to_array(row_wise_results.into_iter())?
+ })
+ }
}
diff --git a/datafusion/physical-expr/src/window/window_expr.rs
b/datafusion/physical-expr/src/window/window_expr.rs
index a718fa4cd..656b6723b 100644
--- a/datafusion/physical-expr/src/window/window_expr.rs
+++ b/datafusion/physical-expr/src/window/window_expr.rs
@@ -15,13 +15,16 @@
// specific language governing permissions and limitations
// under the License.
+use crate::window::partition_evaluator::PartitionEvaluator;
use crate::{PhysicalExpr, PhysicalSortExpr};
use arrow::compute::kernels::partition::lexicographical_partition_ranges;
use arrow::compute::kernels::sort::SortColumn;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
-use datafusion_common::{reverse_sort_options, DataFusionError, Result};
-use datafusion_expr::WindowFrame;
+use arrow_schema::DataType;
+use datafusion_common::{reverse_sort_options, DataFusionError, Result,
ScalarValue};
+use datafusion_expr::{Accumulator, WindowFrame};
+use indexmap::IndexMap;
use std::any::Any;
use std::fmt::Debug;
use std::ops::Range;
@@ -61,6 +64,18 @@ pub trait WindowExpr: Send + Sync + Debug {
/// evaluate the window function values against the batch
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
+    /// Incrementally evaluate the window function over buffered partition batches, updating per-partition state
+ fn evaluate_stateful(
+ &self,
+ _partition_batches: &PartitionBatches,
+ _window_agg_state: &mut PartitionWindowAggStates,
+ ) -> Result<()> {
+ Err(DataFusionError::Internal(format!(
+ "evaluate_stateful is not implemented for {}",
+ self.name()
+ )))
+ }
+
/// evaluate the partition points given the sort columns; if the sort
columns are
/// empty then the result will be a single element vec of the whole column
rows.
fn evaluate_partition_points(
@@ -116,6 +131,10 @@ pub trait WindowExpr: Send + Sync + Debug {
/// Get the window frame of this [WindowExpr].
fn get_window_frame(&self) -> &Arc<WindowFrame>;
+ /// Return a flag indicating whether this [WindowExpr] can run with
+ /// bounded memory.
+ fn uses_bounded_memory(&self) -> bool;
+
/// Get the reverse expression of this [WindowExpr].
fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;
}
@@ -132,3 +151,118 @@ pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr])
-> Vec<PhysicalSortExpr
})
.collect()
}
+
+#[derive(Debug)]
+pub enum WindowFn {
+ Builtin(Box<dyn PartitionEvaluator>),
+ Aggregate(Box<dyn Accumulator>),
+}
+
+/// State for RANK(percent_rank, rank, dense_rank)
+/// builtin window function
+#[derive(Debug, Clone, Default)]
+pub struct RankState {
+ /// The last values for rank as these values change, we increase n_rank
+ pub last_rank_data: Vec<ScalarValue>,
+ /// The index where last_rank_boundary is started
+ pub last_rank_boundary: usize,
+ /// Rank number kept from the start
+ pub n_rank: usize,
+}
+
+/// State for 'ROW_NUMBER' builtin window function
+#[derive(Debug, Clone, Default)]
+pub struct NumRowsState {
+ pub n_rows: usize,
+}
+
+#[derive(Debug, Clone, Default)]
+pub struct NthValueState {
+ pub range: Range<usize>,
+}
+
+#[derive(Debug, Clone, Default)]
+pub struct LeadLagState {
+ pub idx: usize,
+}
+
+#[derive(Debug, Clone, Default)]
+pub enum BuiltinWindowState {
+ Rank(RankState),
+ NumRows(NumRowsState),
+ NthValue(NthValueState),
+ LeadLag(LeadLagState),
+ #[default]
+ Default,
+}
+#[derive(Debug)]
+pub enum WindowFunctionState {
+ /// Different Aggregate functions may have different state definitions
+ /// In [Accumulator] trait, [fn state(&self) -> Result<Vec<ScalarValue>>]
implementation
+ /// dictates that.
+ AggregateState(Vec<ScalarValue>),
+ /// BuiltinWindowState
+ BuiltinWindowState(BuiltinWindowState),
+}
+
+#[derive(Debug)]
+pub struct WindowAggState {
+ /// The range that we calculate the window function
+ pub window_frame_range: Range<usize>,
+ /// The index of the last row that its result is calculated inside the
partition record batch buffer.
+ pub last_calculated_index: usize,
+ /// The offset of the deleted row number
+ pub offset_pruned_rows: usize,
+ /// State of the window function, required to calculate its result
+ // For instance, for ROW_NUMBER we keep the row index counter to generate
correct result
+ pub window_function_state: WindowFunctionState,
+ /// Stores the results calculated by window frame
+ pub out_col: ArrayRef,
+ /// Keeps track of how many rows should be generated to be in sync with
input record_batch.
+ // (For each row in the input record batch we need to generate a window
result).
+ pub n_row_result_missing: usize,
+ /// flag indicating whether we have received all data for this partition
+ pub is_end: bool,
+}
+
+/// State for each unique partition determined according to PARTITION BY
column(s)
+#[derive(Debug)]
+pub struct PartitionBatchState {
+ /// The record_batch belonging to current partition
+ pub record_batch: RecordBatch,
+ /// flag indicating whether we have received all data for this partition
+ pub is_end: bool,
+}
+
+/// key for IndexMap for each unique partition
+/// For instance, if window frame is OVER(PARTITION BY a,b)
+/// PartitionKey would consist of unique [a,b] pairs
+pub type PartitionKey = Vec<ScalarValue>;
+
+#[derive(Debug)]
+pub struct WindowState {
+ pub state: WindowAggState,
+ pub window_fn: WindowFn,
+}
+pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
+
+/// The IndexMap (i.e. an ordered HashMap) where record batches are separated
for each partition.
+pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
+
+impl WindowAggState {
+ pub fn new(
+ out_type: &DataType,
+ window_function_state: WindowFunctionState,
+ ) -> Result<Self> {
+ let empty_out_col =
ScalarValue::try_from(out_type)?.to_array_of_size(0);
+ Ok(Self {
+ window_frame_range: Range { start: 0, end: 0 },
+ last_calculated_index: 0,
+ offset_pruned_rows: 0,
+ window_function_state,
+ out_col: empty_out_col,
+ n_row_result_missing: 0,
+ is_end: false,
+ })
+ }
+}
diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs
index 4002a49cf..dfd878275 100644
--- a/test-utils/src/lib.rs
+++ b/test-utils/src/lib.rs
@@ -50,7 +50,10 @@ pub fn partitions_to_sorted_vec(partitions:
&[Vec<RecordBatch>]) -> Vec<Option<i
}
/// Adds a random number of empty record batches into the stream
-fn add_empty_batches(batches: Vec<RecordBatch>, rng: &mut StdRng) ->
Vec<RecordBatch> {
+pub fn add_empty_batches(
+ batches: Vec<RecordBatch>,
+ rng: &mut StdRng,
+) -> Vec<RecordBatch> {
let schema = batches[0].schema();
batches