This is an automated email from the ASF dual-hosted git repository.
yjshen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new dafe99733e feat: Support SQL filter clause for aggregate expressions,
add SQL dialect support (#5868)
dafe99733e is described below
commit dafe99733e0f97bfb5ef750f02d02abcb641682d
Author: Yijie Shen <[email protected]>
AuthorDate: Wed Apr 12 02:09:23 2023 +0800
feat: Support SQL filter clause for aggregate expressions, add SQL dialect
support (#5868)
---
datafusion/common/src/config.rs | 4 +
datafusion/core/src/execution/context.rs | 35 +++-
.../src/physical_optimizer/aggregate_statistics.rs | 13 ++
.../src/physical_optimizer/dist_enforcement.rs | 8 +
.../core/src/physical_optimizer/repartition.rs | 2 +
.../src/physical_optimizer/sort_enforcement.rs | 1 +
.../core/src/physical_plan/aggregates/mod.rs | 33 ++++
.../src/physical_plan/aggregates/no_grouping.rs | 35 +++-
.../core/src/physical_plan/aggregates/row_hash.rs | 211 +++++++++++++--------
datafusion/core/src/physical_plan/filter.rs | 2 +-
datafusion/core/src/physical_plan/planner.rs | 57 ++++--
.../tests/sqllogictests/test_files/aggregate.slt | 80 ++++++++
.../test_files/information_schema.slt | 1 +
datafusion/expr/src/tree_node/expr.rs | 4 +-
datafusion/optimizer/src/push_down_projection.rs | 2 +-
datafusion/proto/proto/datafusion.proto | 5 +
datafusion/proto/src/generated/pbjson.rs | 109 +++++++++++
datafusion/proto/src/generated/prost.rs | 8 +
datafusion/proto/src/physical_plan/mod.rs | 22 +++
datafusion/proto/src/physical_plan/to_proto.rs | 13 ++
datafusion/sql/tests/integration_test.rs | 4 +-
docs/source/user-guide/configs.md | 1 +
22 files changed, 548 insertions(+), 102 deletions(-)
diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs
index 55cdc36d20..5973bf262e 100644
--- a/datafusion/common/src/config.rs
+++ b/datafusion/common/src/config.rs
@@ -187,6 +187,10 @@ config_namespace! {
/// When set to true, the SQL parser will normalize identifiers (convert
identifiers to lowercase when not quoted)
pub enable_ident_normalization: bool, default = true
+ /// Configure the SQL dialect used by DataFusion's parser; supported
values include: Generic,
+ /// MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL,
ClickHouse, BigQuery, and Ansi.
+ pub dialect: String, default = "generic".to_string()
+
}
}
diff --git a/datafusion/core/src/execution/context.rs
b/datafusion/core/src/execution/context.rs
index c3adb4cc74..5114adec72 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -91,6 +91,11 @@ use datafusion_sql::{
planner::{ContextProvider, SqlToRel},
};
use parquet::file::properties::WriterProperties;
+use sqlparser::dialect::{
+ AnsiDialect, BigQueryDialect, ClickHouseDialect, Dialect, GenericDialect,
+ HiveDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect,
RedshiftSqlDialect,
+ SQLiteDialect, SnowflakeDialect,
+};
use url::Url;
use crate::catalog::information_schema::{InformationSchemaProvider,
INFORMATION_SCHEMA};
@@ -1500,8 +1505,10 @@ impl SessionState {
pub fn sql_to_statement(
&self,
sql: &str,
+ dialect: &str,
) -> Result<datafusion_sql::parser::Statement> {
- let mut statements = DFParser::parse_sql(sql)?;
+ let dialect = create_dialect_from_str(dialect)?;
+ let mut statements = DFParser::parse_sql_with_dialect(sql,
dialect.as_ref())?;
if statements.len() > 1 {
return Err(DataFusionError::NotImplemented(
"The context currently only supports a single SQL
statement".to_string(),
@@ -1629,7 +1636,8 @@ impl SessionState {
///
/// See [`SessionContext::sql`] for a higher-level interface that also
handles DDL
pub async fn create_logical_plan(&self, sql: &str) -> Result<LogicalPlan> {
- let statement = self.sql_to_statement(sql)?;
+ let dialect = self.config.options().sql_parser.dialect.as_str();
+ let statement = self.sql_to_statement(sql, dialect)?;
let plan = self.statement_to_plan(statement).await?;
Ok(plan)
}
@@ -1838,6 +1846,29 @@ impl From<&SessionState> for TaskContext {
}
}
+// TODO: remove when https://github.com/sqlparser-rs/sqlparser-rs/pull/848 is
released
+fn create_dialect_from_str(dialect_name: &str) -> Result<Box<dyn Dialect>> {
+ match dialect_name.to_lowercase().as_str() {
+ "generic" => Ok(Box::new(GenericDialect)),
+ "mysql" => Ok(Box::new(MySqlDialect {})),
+ "postgresql" | "postgres" => Ok(Box::new(PostgreSqlDialect {})),
+ "hive" => Ok(Box::new(HiveDialect {})),
+ "sqlite" => Ok(Box::new(SQLiteDialect {})),
+ "snowflake" => Ok(Box::new(SnowflakeDialect)),
+ "redshift" => Ok(Box::new(RedshiftSqlDialect {})),
+ "mssql" => Ok(Box::new(MsSqlDialect {})),
+ "clickhouse" => Ok(Box::new(ClickHouseDialect {})),
+ "bigquery" => Ok(Box::new(BigQueryDialect)),
+ "ansi" => Ok(Box::new(AnsiDialect {})),
+ _ => {
+ Err(DataFusionError::Internal(format!(
+ "Unsupported SQL dialect: {}. Available dialects: Generic,
MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse,
BigQuery, Ansi.",
+ dialect_name
+ )))
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index 59806a0a2f..b88f73d8c2 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -123,6 +123,7 @@ fn take_optimizable(node: &dyn ExecutionPlan) ->
Option<Arc<dyn ExecutionPlan>>
{
if partial_agg_exec.mode() == &AggregateMode::Partial
&& partial_agg_exec.group_expr().is_empty()
+ && partial_agg_exec.filter_expr().iter().all(|e|
e.is_none())
{
let stats = partial_agg_exec.input().statistics();
if stats.is_exact {
@@ -410,6 +411,7 @@ mod tests {
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
+ vec![None],
source,
Arc::clone(&schema),
)?;
@@ -418,6 +420,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
+ vec![None],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
@@ -438,6 +441,7 @@ mod tests {
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
+ vec![None],
source,
Arc::clone(&schema),
)?;
@@ -446,6 +450,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
+ vec![None],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
@@ -465,6 +470,7 @@ mod tests {
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
+ vec![None],
source,
Arc::clone(&schema),
)?;
@@ -476,6 +482,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
+ vec![None],
Arc::new(coalesce),
Arc::clone(&schema),
)?;
@@ -495,6 +502,7 @@ mod tests {
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
+ vec![None],
source,
Arc::clone(&schema),
)?;
@@ -506,6 +514,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
+ vec![None],
Arc::new(coalesce),
Arc::clone(&schema),
)?;
@@ -536,6 +545,7 @@ mod tests {
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
+ vec![None],
filter,
Arc::clone(&schema),
)?;
@@ -544,6 +554,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
+ vec![None],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
@@ -579,6 +590,7 @@ mod tests {
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
+ vec![None],
filter,
Arc::clone(&schema),
)?;
@@ -587,6 +599,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
+ vec![None],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs
b/datafusion/core/src/physical_optimizer/dist_enforcement.rs
index d3e99945e9..affe432830 100644
--- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs
+++ b/datafusion/core/src/physical_optimizer/dist_enforcement.rs
@@ -252,6 +252,7 @@ fn adjust_input_keys_ordering(
mode,
group_by,
aggr_expr,
+ filter_expr,
input,
input_schema,
..
@@ -264,6 +265,7 @@ fn adjust_input_keys_ordering(
&parent_required,
group_by,
aggr_expr,
+ filter_expr,
input.clone(),
input_schema,
)?),
@@ -369,6 +371,7 @@ fn reorder_aggregate_keys(
parent_required: &[Arc<dyn PhysicalExpr>],
group_by: &PhysicalGroupBy,
aggr_expr: &[Arc<dyn AggregateExpr>],
+ filter_expr: &[Option<Arc<dyn PhysicalExpr>>],
agg_input: Arc<dyn ExecutionPlan>,
input_schema: &SchemaRef,
) -> Result<PlanWithKeyRequirements> {
@@ -398,6 +401,7 @@ fn reorder_aggregate_keys(
mode,
group_by,
aggr_expr,
+ filter_expr,
input,
input_schema,
..
@@ -416,6 +420,7 @@ fn reorder_aggregate_keys(
AggregateMode::Partial,
new_partial_group_by,
aggr_expr.clone(),
+ filter_expr.clone(),
input.clone(),
input_schema.clone(),
)?))
@@ -446,6 +451,7 @@ fn reorder_aggregate_keys(
AggregateMode::FinalPartitioned,
new_group_by,
aggr_expr.to_vec(),
+ filter_expr.to_vec(),
partial_agg,
input_schema.clone(),
)?);
@@ -1067,11 +1073,13 @@ mod tests {
AggregateMode::FinalPartitioned,
final_grouping,
vec![],
+ vec![],
Arc::new(
AggregateExec::try_new(
AggregateMode::Partial,
group_by,
vec![],
+ vec![],
input,
schema.clone(),
)
diff --git a/datafusion/core/src/physical_optimizer/repartition.rs
b/datafusion/core/src/physical_optimizer/repartition.rs
index 3bb21b12be..1db61e379e 100644
--- a/datafusion/core/src/physical_optimizer/repartition.rs
+++ b/datafusion/core/src/physical_optimizer/repartition.rs
@@ -477,11 +477,13 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![],
+ vec![],
Arc::new(
AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![],
+ vec![],
input,
schema.clone(),
)
diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs
b/datafusion/core/src/physical_optimizer/sort_enforcement.rs
index b1a4da65e0..bada74193b 100644
--- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs
+++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs
@@ -2469,6 +2469,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![],
+ vec![],
input,
schema,
)
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs
b/datafusion/core/src/physical_plan/aggregates/mod.rs
index ade0fa0066..3cc8fd5d7d 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -171,6 +171,8 @@ pub struct AggregateExec {
pub(crate) group_by: PhysicalGroupBy,
/// Aggregate expressions
pub(crate) aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+ /// FILTER (WHERE clause) expression for each aggregate expression
+ pub(crate) filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
/// Input plan, could be a partial aggregate or the input to the aggregate
pub(crate) input: Arc<dyn ExecutionPlan>,
/// Schema after the aggregate is applied
@@ -192,6 +194,7 @@ impl AggregateExec {
mode: AggregateMode,
group_by: PhysicalGroupBy,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+ filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
) -> Result<Self> {
@@ -221,6 +224,7 @@ impl AggregateExec {
mode,
group_by,
aggr_expr,
+ filter_expr,
input,
schema,
input_schema,
@@ -258,6 +262,11 @@ impl AggregateExec {
&self.aggr_expr
}
+ /// FILTER (WHERE clause) expression for each aggregate expression
+ pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
+ &self.filter_expr
+ }
+
/// Input plan
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
&self.input
@@ -281,6 +290,7 @@ impl AggregateExec {
self.mode,
self.schema.clone(),
self.aggr_expr.clone(),
+ self.filter_expr.clone(),
input,
baseline_metrics,
context,
@@ -293,6 +303,7 @@ impl AggregateExec {
self.schema.clone(),
self.group_by.clone(),
self.aggr_expr.clone(),
+ self.filter_expr.clone(),
input,
baseline_metrics,
batch_size,
@@ -391,6 +402,7 @@ impl ExecutionPlan for AggregateExec {
self.mode,
self.group_by.clone(),
self.aggr_expr.clone(),
+ self.filter_expr.clone(),
children[0].clone(),
self.input_schema.clone(),
)?))
@@ -703,6 +715,20 @@ fn evaluate_many(
.collect::<Result<Vec<_>>>()
}
+fn evaluate_optional(
+ expr: &[Option<Arc<dyn PhysicalExpr>>],
+ batch: &RecordBatch,
+) -> Result<Vec<Option<ArrayRef>>> {
+ expr.iter()
+ .map(|expr| {
+ expr.as_ref()
+ .map(|expr| expr.evaluate(batch))
+ .transpose()
+ .map(|r| r.map(|v| v.into_array(batch.num_rows())))
+ })
+ .collect::<Result<Vec<_>>>()
+}
+
fn evaluate_group_by(
group_by: &PhysicalGroupBy,
batch: &RecordBatch,
@@ -839,6 +865,7 @@ mod tests {
AggregateMode::Partial,
grouping_set.clone(),
aggregates.clone(),
+ vec![None],
input,
input_schema.clone(),
)?);
@@ -881,6 +908,7 @@ mod tests {
AggregateMode::Final,
final_grouping_set,
aggregates,
+ vec![None],
merge,
input_schema,
)?);
@@ -944,6 +972,7 @@ mod tests {
AggregateMode::Partial,
grouping_set.clone(),
aggregates.clone(),
+ vec![None],
input,
input_schema.clone(),
)?);
@@ -976,6 +1005,7 @@ mod tests {
AggregateMode::Final,
final_grouping_set,
aggregates,
+ vec![None],
merge,
input_schema,
)?);
@@ -1191,6 +1221,7 @@ mod tests {
AggregateMode::Partial,
groups,
aggregates,
+ vec![None; 3],
input.clone(),
input_schema.clone(),
)?);
@@ -1246,6 +1277,7 @@ mod tests {
AggregateMode::Partial,
groups.clone(),
aggregates.clone(),
+ vec![None],
blocking_exec,
schema,
)?);
@@ -1284,6 +1316,7 @@ mod tests {
AggregateMode::Partial,
groups,
aggregates.clone(),
+ vec![None],
blocking_exec,
schema,
)?);
diff --git a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
index c13f005b03..efeae8716d 100644
--- a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
+++ b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
@@ -29,10 +29,12 @@ use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};
use futures::stream::BoxStream;
+use std::borrow::Cow;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use crate::physical_plan::filter::batch_filter;
use futures::stream::{Stream, StreamExt};
/// stream struct for aggregation without grouping columns
@@ -52,23 +54,32 @@ struct AggregateStreamInner {
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+ filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
accumulators: Vec<AccumulatorItem>,
reservation: MemoryReservation,
finished: bool,
}
impl AggregateStream {
+ #[allow(clippy::too_many_arguments)]
/// Create a new AggregateStream
pub fn new(
mode: AggregateMode,
schema: SchemaRef,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+ filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
context: Arc<TaskContext>,
partition: usize,
) -> Result<Self> {
let aggregate_expressions = aggregate_expressions(&aggr_expr, &mode,
0)?;
+ let filter_expressions = match mode {
+ AggregateMode::Partial => filter_expr,
+ AggregateMode::Final | AggregateMode::FinalPartitioned => {
+ vec![None; aggr_expr.len()]
+ }
+ };
let accumulators = create_accumulators(&aggr_expr)?;
let reservation =
MemoryConsumer::new(format!("AggregateStream[{partition}]"))
@@ -80,6 +91,7 @@ impl AggregateStream {
input,
baseline_metrics,
aggregate_expressions,
+ filter_expressions,
accumulators,
reservation,
finished: false,
@@ -97,9 +109,10 @@ impl AggregateStream {
let timer = elapsed_compute.timer();
let result = aggregate_batch(
&this.mode,
- &batch,
+ batch,
&mut this.accumulators,
&this.aggregate_expressions,
+ &this.filter_expressions,
);
timer.done();
@@ -169,29 +182,37 @@ impl RecordBatchStream for AggregateStream {
/// TODO: Make this a member function
fn aggregate_batch(
mode: &AggregateMode,
- batch: &RecordBatch,
+ batch: RecordBatch,
accumulators: &mut [AccumulatorItem],
expressions: &[Vec<Arc<dyn PhysicalExpr>>],
+ filters: &[Option<Arc<dyn PhysicalExpr>>],
) -> Result<usize> {
let mut allocated = 0usize;
// 1.1 iterate accumulators and respective expressions together
- // 1.2 evaluate expressions
- // 1.3 update / merge accumulators with the expressions' values
+ // 1.2 filter the batch if necessary
+ // 1.3 evaluate expressions
+ // 1.4 update / merge accumulators with the expressions' values
// 1.1
accumulators
.iter_mut()
.zip(expressions)
- .try_for_each(|(accum, expr)| {
+ .zip(filters)
+ .try_for_each(|((accum, expr), filter)| {
// 1.2
+ let batch = match filter {
+ Some(filter) => Cow::Owned(batch_filter(&batch, filter)?),
+ None => Cow::Borrowed(&batch),
+ };
+ // 1.3
let values = &expr
.iter()
- .map(|e| e.evaluate(batch))
+ .map(|e| e.evaluate(&batch))
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
.collect::<Result<Vec<_>>>()?;
- // 1.3
+ // 1.4
let size_pre = accum.size();
let res = match mode {
AggregateMode::Partial => accum.update_batch(values),
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index 42ba9f8cb3..3cc2442543 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -32,8 +32,8 @@ use futures::stream::{Stream, StreamExt};
use crate::execution::context::TaskContext;
use crate::execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
use crate::physical_plan::aggregates::{
- evaluate_group_by, evaluate_many, group_schema, AccumulatorItem,
AggregateMode,
- PhysicalGroupBy, RowAccumulatorItem,
+ evaluate_group_by, evaluate_many, evaluate_optional, group_schema,
AccumulatorItem,
+ AggregateMode, PhysicalGroupBy, RowAccumulatorItem,
};
use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr};
@@ -41,9 +41,10 @@ use crate::physical_plan::{RecordBatchStream,
SendableRecordBatchStream};
use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
use arrow::array::{new_null_array, Array, ArrayRef, PrimitiveArray,
UInt32Builder};
-use arrow::compute::cast;
+use arrow::compute::{cast, filter};
use arrow::datatypes::{DataType, Schema, UInt32Type};
-use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
+use arrow::{compute, datatypes::SchemaRef, record_batch::RecordBatch};
+use datafusion_common::cast::as_boolean_array;
use datafusion_common::utils::get_arrayref_at_indices;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;
@@ -73,21 +74,26 @@ pub(crate) struct GroupedHashAggregateStream {
schema: SchemaRef,
input: SendableRecordBatchStream,
mode: AggregateMode,
- exec_state: ExecutionState,
+
normal_aggr_expr: Vec<Arc<dyn AggregateExpr>>,
- row_aggr_state: RowAggregationState,
/// Aggregate expressions not supporting row accumulation
normal_aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+ /// Filter expression for each normal aggregate expression
+ normal_filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
+
/// Aggregate expressions supporting row accumulation
row_aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
-
- group_by: PhysicalGroupBy,
+ /// Filter expression for each row aggregate expression
+ row_filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
row_accumulators: Vec<RowAccumulatorItem>,
-
row_converter: RowConverter,
row_aggr_schema: SchemaRef,
row_aggr_layout: Arc<RowLayout>,
+ group_by: PhysicalGroupBy,
+
+ aggr_state: AggregationState,
+ exec_state: ExecutionState,
baseline_metrics: BaselineMetrics,
random_state: RandomState,
/// size to be used for resulting RecordBatches
@@ -125,6 +131,7 @@ impl GroupedHashAggregateStream {
schema: SchemaRef,
group_by: PhysicalGroupBy,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+ filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
batch_size: usize,
@@ -137,15 +144,26 @@ impl GroupedHashAggregateStream {
let mut row_aggr_expr = vec![];
let mut row_agg_indices = vec![];
let mut row_aggregate_expressions = vec![];
+ let mut row_filter_expressions = vec![];
let mut normal_aggr_expr = vec![];
let mut normal_agg_indices = vec![];
let mut normal_aggregate_expressions = vec![];
+ let mut normal_filter_expressions = vec![];
// The expressions to evaluate the batch, one vec of expressions per
aggregation.
// Assuming create_schema() always puts group columns in front of
aggregation columns, we set
// col_idx_base to the group expression count.
let all_aggregate_expressions =
aggregates::aggregate_expressions(&aggr_expr, &mode, start_idx)?;
- for (expr, others) in
aggr_expr.iter().zip(all_aggregate_expressions.into_iter())
+ let filter_expressions = match mode {
+ AggregateMode::Partial => filter_expr,
+ AggregateMode::Final | AggregateMode::FinalPartitioned => {
+ vec![None; aggr_expr.len()]
+ }
+ };
+ for ((expr, others), filter) in aggr_expr
+ .iter()
+ .zip(all_aggregate_expressions.into_iter())
+ .zip(filter_expressions.into_iter())
{
let n_fields = match mode {
// In partial aggregation, we keep additional fields in order
to successfully
@@ -160,10 +178,12 @@ impl GroupedHashAggregateStream {
};
if expr.row_accumulator_supported() {
row_aggregate_expressions.push(others);
+ row_filter_expressions.push(filter.clone());
row_agg_indices.push(aggr_range);
row_aggr_expr.push(expr.clone());
} else {
normal_aggregate_expressions.push(others);
+ normal_filter_expressions.push(filter.clone());
normal_agg_indices.push(aggr_range);
normal_aggr_expr.push(expr.clone());
}
@@ -187,7 +207,7 @@ impl GroupedHashAggregateStream {
Arc::new(RowLayout::new(&row_aggr_schema, RowType::WordAligned));
let name = format!("GroupedHashAggregateStream[{partition}]");
- let row_aggr_state = RowAggregationState {
+ let aggr_state = AggregationState {
reservation:
MemoryConsumer::new(name).register(context.memory_pool()),
map: RawTable::with_capacity(0),
group_states: Vec::with_capacity(0),
@@ -199,19 +219,21 @@ impl GroupedHashAggregateStream {
Ok(GroupedHashAggregateStream {
schema: Arc::clone(&schema),
- mode,
- exec_state,
input,
- group_by,
+ mode,
normal_aggr_expr,
+ normal_aggregate_expressions,
+ normal_filter_expressions,
+ row_aggregate_expressions,
+ row_filter_expressions,
row_accumulators,
row_converter,
row_aggr_schema,
row_aggr_layout,
+ group_by,
+ aggr_state,
+ exec_state,
baseline_metrics,
- normal_aggregate_expressions,
- row_aggregate_expressions,
- row_aggr_state,
random_state: Default::default(),
batch_size,
row_group_skip_position: 0,
@@ -243,7 +265,7 @@ impl Stream for GroupedHashAggregateStream {
// This happens AFTER we actually used the memory,
but simplifies the whole accounting and we are OK with
// overshooting a bit. Also this means we either
store the whole record batch or not.
let result = result.and_then(|allocated| {
-
self.row_aggr_state.reservation.try_grow(allocated)
+ self.aggr_state.reservation.try_grow(allocated)
});
if let Err(e) = result {
@@ -312,25 +334,23 @@ impl GroupedHashAggregateStream {
let mut batch_hashes = vec![0; n_rows];
create_hashes(group_values, &self.random_state, &mut batch_hashes)?;
- let RowAggregationState {
- map: row_map,
- group_states: row_group_states,
- ..
- } = &mut self.row_aggr_state;
+ let AggregationState {
+ map, group_states, ..
+ } = &mut self.aggr_state;
for (row, hash) in batch_hashes.into_iter().enumerate() {
- let entry = row_map.get_mut(hash, |(_hash, group_idx)| {
+ let entry = map.get_mut(hash, |(_hash, group_idx)| {
// verify that a group that we are inserting with hash is
// actually the same key value as the group in
// existing_idx (aka group_values @ row)
- let group_state = &row_group_states[*group_idx];
+ let group_state = &group_states[*group_idx];
group_rows.row(row) == group_state.group_by_values.row()
});
match entry {
// Existing entry for this group value
Some((_hash, group_idx)) => {
- let group_state = &mut row_group_states[*group_idx];
+ let group_state = &mut group_states[*group_idx];
// 1.3
if group_state.indices.is_empty() {
@@ -344,7 +364,7 @@ impl GroupedHashAggregateStream {
let accumulator_set =
aggregates::create_accumulators(&self.normal_aggr_expr)?;
// Add new entry to group_states and save newly created
index
- let group_state = RowGroupState {
+ let group_state = GroupState {
group_by_values: group_rows.row(row).owned(),
aggregation_buffer: vec![
0;
@@ -353,9 +373,9 @@ impl GroupedHashAggregateStream {
accumulator_set,
indices: vec![row as u32], // 1.3
};
- let group_idx = row_group_states.len();
+ let group_idx = group_states.len();
- // NOTE: do NOT include the `RowGroupState` struct size in
here because this is captured by
+ // NOTE: do NOT include the `GroupState` struct size in
here because this is captured by
// `group_states` (see allocation down below)
*allocated += (std::mem::size_of::<u8>()
* group_state.group_by_values.as_ref().len())
@@ -373,13 +393,13 @@ impl GroupedHashAggregateStream {
.sum::<usize>();
// for hasher function, use precomputed hash value
- row_map.insert_accounted(
+ map.insert_accounted(
(hash, group_idx),
|(hash, _group_index)| *hash,
allocated,
);
- row_group_states.push_accounted(group_state, allocated);
+ group_states.push_accounted(group_state, allocated);
groups_with_rows.push(group_idx);
}
@@ -389,12 +409,15 @@ impl GroupedHashAggregateStream {
}
// Update the accumulator results, according to row_aggr_state.
+ #[allow(clippy::too_many_arguments)]
fn update_accumulators(
&mut self,
groups_with_rows: &[usize],
offsets: &[usize],
row_values: &[Vec<ArrayRef>],
normal_values: &[Vec<ArrayRef>],
+ row_filter_values: &[Option<ArrayRef>],
+ normal_filter_values: &[Option<ArrayRef>],
allocated: &mut usize,
) -> Result<()> {
// 2.1 for each key in this batch
@@ -406,24 +429,19 @@ impl GroupedHashAggregateStream {
.iter()
.zip(offsets.windows(2))
.try_for_each(|(group_idx, offsets)| {
- let group_state = &mut
self.row_aggr_state.group_states[*group_idx];
+ let group_state = &mut
self.aggr_state.group_states[*group_idx];
// 2.2
+ // Process row accumulators
self.row_accumulators
.iter_mut()
.zip(row_values.iter())
- .map(|(accumulator, aggr_array)| {
- (
- accumulator,
- aggr_array
- .iter()
- .map(|array| {
- // 2.3
- array.slice(offsets[0], offsets[1] -
offsets[0])
- })
- .collect::<Vec<ArrayRef>>(),
- )
- })
- .try_for_each(|(accumulator, values)| {
+ .zip(row_filter_values.iter())
+ .try_for_each(|((accumulator, aggr_array), filter_opt)| {
+ let values = slice_and_maybe_filter(
+ aggr_array,
+ filter_opt.as_ref(),
+ offsets,
+ )?;
let mut state_accessor =
RowAccessor::new_from_layout(self.row_aggr_layout.clone());
state_accessor
@@ -437,27 +455,19 @@ impl GroupedHashAggregateStream {
accumulator.merge_batch(&values, &mut
state_accessor)
}
}
- })
- // 2.5
- .and(Ok(()))?;
+ })?;
// normal accumulators
group_state
.accumulator_set
.iter_mut()
.zip(normal_values.iter())
- .map(|(accumulator, aggr_array)| {
- (
- accumulator,
- aggr_array
- .iter()
- .map(|array| {
- // 2.3
- array.slice(offsets[0], offsets[1] -
offsets[0])
- })
- .collect::<Vec<ArrayRef>>(),
- )
- })
- .try_for_each(|(accumulator, values)| {
+ .zip(normal_filter_values.iter())
+ .try_for_each(|((accumulator, aggr_array), filter_opt)| {
+ let values = slice_and_maybe_filter(
+ aggr_array,
+ filter_opt.as_ref(),
+ offsets,
+ )?;
let size_pre = accumulator.size();
let res = match self.mode {
AggregateMode::Partial =>
accumulator.update_batch(&values),
@@ -496,6 +506,9 @@ impl GroupedHashAggregateStream {
evaluate_many(&self.row_aggregate_expressions, &batch)?;
let normal_aggr_input_values =
evaluate_many(&self.normal_aggregate_expressions, &batch)?;
+ let row_filter_values =
evaluate_optional(&self.row_filter_expressions, &batch)?;
+ let normal_filter_values =
+ evaluate_optional(&self.normal_filter_expressions, &batch)?;
let row_converter_size_pre = self.row_converter.size();
for group_values in &group_by_values {
@@ -507,7 +520,7 @@ impl GroupedHashAggregateStream {
let mut offsets = vec![0];
let mut offset_so_far = 0;
for &group_idx in groups_with_rows.iter() {
- let indices =
&self.row_aggr_state.group_states[group_idx].indices;
+ let indices = &self.aggr_state.group_states[group_idx].indices;
batch_indices.append_slice(indices);
offset_so_far += indices.len();
offsets.push(offset_so_far);
@@ -517,11 +530,17 @@ impl GroupedHashAggregateStream {
let row_values = get_at_indices(&row_aggr_input_values,
&batch_indices)?;
let normal_values =
get_at_indices(&normal_aggr_input_values, &batch_indices)?;
+ let row_filter_values =
+ get_optional_filters(&row_filter_values, &batch_indices);
+ let normal_filter_values =
+ get_optional_filters(&normal_filter_values, &batch_indices);
self.update_accumulators(
&groups_with_rows,
&offsets,
&row_values,
&normal_values,
+ &row_filter_values,
+ &normal_filter_values,
&mut allocated,
)?;
}
@@ -535,7 +554,7 @@ impl GroupedHashAggregateStream {
/// The state that is built for each output group.
#[derive(Debug)]
-pub struct RowGroupState {
+pub struct GroupState {
/// The actual group by values, stored sequentially
group_by_values: OwnedRow,
@@ -551,7 +570,7 @@ pub struct RowGroupState {
}
/// The state of all the groups
-pub struct RowAggregationState {
+pub struct AggregationState {
pub reservation: MemoryReservation,
/// Logically maps group values to an index in `group_states`
@@ -564,10 +583,10 @@ pub struct RowAggregationState {
pub map: RawTable<(u64, usize)>,
/// State for each group
- pub group_states: Vec<RowGroupState>,
+ pub group_states: Vec<GroupState>,
}
-impl std::fmt::Debug for RowAggregationState {
+impl std::fmt::Debug for AggregationState {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
// hashes are not store inline, so could only get values
let map_string = "RawTable";
@@ -582,19 +601,19 @@ impl GroupedHashAggregateStream {
/// Create a RecordBatch with all group keys and accumulator' states or
values.
fn create_batch_from_map(&mut self) -> Result<Option<RecordBatch>> {
let skip_items = self.row_group_skip_position;
- if skip_items > self.row_aggr_state.group_states.len() {
+ if skip_items > self.aggr_state.group_states.len() {
return Ok(None);
}
- if self.row_aggr_state.group_states.is_empty() {
+ if self.aggr_state.group_states.is_empty() {
let schema = self.schema.clone();
return Ok(Some(RecordBatch::new_empty(schema)));
}
let end_idx = min(
skip_items + self.batch_size,
- self.row_aggr_state.group_states.len(),
+ self.aggr_state.group_states.len(),
);
- let group_state_chunk =
&self.row_aggr_state.group_states[skip_items..end_idx];
+ let group_state_chunk =
&self.aggr_state.group_states[skip_items..end_idx];
if group_state_chunk.is_empty() {
let schema = self.schema.clone();
@@ -648,8 +667,8 @@ impl GroupedHashAggregateStream {
for (field_idx, field) in
output_fields[start..end].iter().enumerate() {
let current = match self.mode {
AggregateMode::Partial => ScalarValue::iter_to_array(
- group_state_chunk.iter().map(|row_group_state| {
- row_group_state.accumulator_set[idx]
+ group_state_chunk.iter().map(|group_state| {
+ group_state.accumulator_set[idx]
.state()
.map(|v| v[field_idx].clone())
.expect("Unexpected accumulator state in hash
aggregate")
@@ -657,8 +676,8 @@ impl GroupedHashAggregateStream {
),
AggregateMode::Final | AggregateMode::FinalPartitioned => {
ScalarValue::iter_to_array(group_state_chunk.iter().map(
- |row_group_state| {
-
row_group_state.accumulator_set[idx].evaluate().expect(
+ |group_state| {
+
group_state.accumulator_set[idx].evaluate().expect(
"Unexpected accumulator state in hash
aggregate",
)
},
@@ -726,3 +745,47 @@ fn get_at_indices(
.map(|array| get_arrayref_at_indices(array, batch_indices))
.collect()
}
+
+fn get_optional_filters(
+ original_values: &[Option<Arc<dyn Array>>],
+ batch_indices: &PrimitiveArray<UInt32Type>,
+) -> Vec<Option<Arc<dyn Array>>> {
+ original_values
+ .iter()
+ .map(|array| {
+ array.as_ref().map(|array| {
+ compute::take(
+ array.as_ref(),
+ batch_indices,
+ None, // None: no index check
+ )
+ .unwrap()
+ })
+ })
+ .collect()
+}
+
+fn slice_and_maybe_filter(
+ aggr_array: &[ArrayRef],
+ filter_opt: Option<&Arc<dyn Array>>,
+ offsets: &[usize],
+) -> Result<Vec<ArrayRef>> {
+ let sliced_arrays: Vec<ArrayRef> = aggr_array
+ .iter()
+ .map(|array| array.slice(offsets[0], offsets[1] - offsets[0]))
+ .collect();
+
+ let filtered_arrays = match filter_opt.as_ref() {
+ Some(f) => {
+ let sliced = f.slice(offsets[0], offsets[1] - offsets[0]);
+ let filter_array = as_boolean_array(&sliced)?;
+
+ sliced_arrays
+ .iter()
+ .map(|array| filter(array, filter_array).unwrap())
+ .collect::<Vec<ArrayRef>>()
+ }
+ None => sliced_arrays,
+ };
+ Ok(filtered_arrays)
+}
diff --git a/datafusion/core/src/physical_plan/filter.rs
b/datafusion/core/src/physical_plan/filter.rs
index a72aa69d07..494d3fc869 100644
--- a/datafusion/core/src/physical_plan/filter.rs
+++ b/datafusion/core/src/physical_plan/filter.rs
@@ -235,7 +235,7 @@ struct FilterExecStream {
baseline_metrics: BaselineMetrics,
}
-fn batch_filter(
+pub(crate) fn batch_filter(
batch: &RecordBatch,
predicate: &Arc<dyn PhysicalExpr>,
) -> Result<RecordBatch> {
diff --git a/datafusion/core/src/physical_plan/planner.rs
b/datafusion/core/src/physical_plan/planner.rs
index 2064357f7d..8ee32b8d04 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -627,10 +627,10 @@ impl DefaultPhysicalPlanner {
&physical_input_schema,
session_state)?;
- let aggregates = aggr_expr
+ let agg_filter = aggr_expr
.iter()
.map(|e| {
- create_aggregate_expr(
+ create_aggregate_expr_and_maybe_filter(
e,
logical_input_schema,
&physical_input_schema,
@@ -638,11 +638,13 @@ impl DefaultPhysicalPlanner {
)
})
.collect::<Result<Vec<_>>>()?;
+ let (aggregates, filters): (Vec<_>, Vec<_>) =
agg_filter.into_iter().unzip();
let initial_aggr = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
groups.clone(),
aggregates.clone(),
+ filters.clone(),
input_exec,
physical_input_schema.clone(),
)?);
@@ -678,6 +680,7 @@ impl DefaultPhysicalPlanner {
next_partition_mode,
final_grouping_set,
aggregates,
+ filters,
initial_aggr,
physical_input_schema.clone(),
)?))
@@ -1609,20 +1612,23 @@ pub fn create_window_expr(
)
}
+type AggregateExprWithOptionalFilter =
+ (Arc<dyn AggregateExpr>, Option<Arc<dyn PhysicalExpr>>);
+
/// Create an aggregate expression with a name from a logical expression
-pub fn create_aggregate_expr_with_name(
+pub fn create_aggregate_expr_with_name_and_maybe_filter(
e: &Expr,
name: impl Into<String>,
logical_input_schema: &DFSchema,
physical_input_schema: &Schema,
execution_props: &ExecutionProps,
-) -> Result<Arc<dyn AggregateExpr>> {
+) -> Result<AggregateExprWithOptionalFilter> {
match e {
Expr::AggregateFunction(AggregateFunction {
fun,
distinct,
args,
- ..
+ filter,
}) => {
let args = args
.iter()
@@ -1635,15 +1641,25 @@ pub fn create_aggregate_expr_with_name(
)
})
.collect::<Result<Vec<_>>>()?;
- aggregates::create_aggregate_expr(
+ let filter = match filter {
+ Some(e) => Some(create_physical_expr(
+ e,
+ logical_input_schema,
+ physical_input_schema,
+ execution_props,
+ )?),
+ None => None,
+ };
+ let agg_expr = aggregates::create_aggregate_expr(
fun,
*distinct,
&args,
physical_input_schema,
name,
- )
+ );
+ Ok((agg_expr?, filter))
}
- Expr::AggregateUDF { fun, args, .. } => {
+ Expr::AggregateUDF { fun, args, filter } => {
let args = args
.iter()
.map(|e| {
@@ -1656,7 +1672,19 @@ pub fn create_aggregate_expr_with_name(
})
.collect::<Result<Vec<_>>>()?;
- udaf::create_aggregate_expr(fun, &args, physical_input_schema,
name)
+ let filter = match filter {
+ Some(e) => Some(create_physical_expr(
+ e,
+ logical_input_schema,
+ physical_input_schema,
+ execution_props,
+ )?),
+ None => None,
+ };
+
+ let agg_expr =
+ udaf::create_aggregate_expr(fun, &args, physical_input_schema,
name);
+ Ok((agg_expr?, filter))
}
other => Err(DataFusionError::Internal(format!(
"Invalid aggregate expression '{other:?}'"
@@ -1665,19 +1693,19 @@ pub fn create_aggregate_expr_with_name(
}
/// Create an aggregate expression from a logical expression or an alias
-pub fn create_aggregate_expr(
+pub fn create_aggregate_expr_and_maybe_filter(
e: &Expr,
logical_input_schema: &DFSchema,
physical_input_schema: &Schema,
execution_props: &ExecutionProps,
-) -> Result<Arc<dyn AggregateExpr>> {
+) -> Result<AggregateExprWithOptionalFilter> {
// unpack (nested) aliased logical expressions, e.g. "sum(col) as total"
let (name, e) = match e {
Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()),
_ => (physical_name(e)?, e),
};
- create_aggregate_expr_with_name(
+ create_aggregate_expr_with_name_and_maybe_filter(
e,
name,
logical_input_schema,
@@ -1788,7 +1816,10 @@ impl DefaultPhysicalPlanner {
"Input physical plan:\n{}\n",
displayable(plan.as_ref()).indent()
);
- trace!("Detailed input physical plan:\n{:?}", plan);
+ trace!(
+ "Detailed input physical plan:\n{}",
+ displayable(plan.as_ref()).indent()
+ );
let mut new_plan = plan;
for optimizer in optimizers {
diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index b049e4b16c..10368341d8 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -1538,3 +1538,83 @@ query RT
select avg(c1), arrow_typeof(avg(c1)) from d_table
----
5 Decimal128(14, 7)
+
+# Use the PostgreSQL dialect
+statement ok
+set datafusion.sql_parser.dialect = 'Postgres';
+
+# Creating the table
+statement ok
+CREATE TABLE test_table (c1 INT, c2 INT, c3 INT)
+
+# Inserting data
+statement ok
+INSERT INTO test_table VALUES (1, 10, 50), (1, 20, 60), (2, 10, 70), (2, 20,
80), (3, 10, NULL)
+
+# query_group_by_with_filter
+query II rowsort
+SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) as result FROM test_table GROUP BY
c1
+----
+1 20
+2 20
+3 NULL
+
+# query_group_by_avg_with_filter
+query IR rowsort
+SELECT c1, AVG(c2) FILTER (WHERE c2 >= 20) AS avg_c2 FROM test_table GROUP BY
c1
+----
+1 20
+2 20
+3 NULL
+
+# query_group_by_with_multiple_filters
+query IIR rowsort
+SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2, AVG(c3) FILTER (WHERE c3
<= 70) AS avg_c3 FROM test_table GROUP BY c1
+----
+1 20 55
+2 20 70
+3 NULL NULL
+
+# query_group_by_distinct_with_filter
+query II rowsort
+SELECT c1, COUNT(DISTINCT c2) FILTER (WHERE c2 >= 20) AS distinct_c2_count
FROM test_table GROUP BY c1
+----
+1 1
+2 1
+3 0
+
+# query_without_group_by_with_filter
+query I rowsort
+SELECT SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2 FROM test_table
+----
+40
+
+# count_without_group_by_with_filter
+query I rowsort
+SELECT COUNT(c2) FILTER (WHERE c2 >= 20) AS count_c2 FROM test_table
+----
+2
+
+# query_with_and_without_filter
+query III rowsort
+SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) as result, SUM(c2) as
result_no_filter FROM test_table GROUP BY c1;
+----
+1 20 30
+2 20 30
+3 NULL 10
+
+# query_filter_on_different_column_than_aggregate
+query I rowsort
+select sum(c1) FILTER (WHERE c2 < 30) from test_table;
+----
+9
+
+# query_test_empty_filter
+query I rowsort
+SELECT SUM(c2) FILTER (WHERE c2 >= 20000000) AS sum_c2 FROM test_table;
+----
+NULL
+
+# Restore the default dialect
+statement ok
+set datafusion.sql_parser.dialect = 'Generic';
diff --git
a/datafusion/core/tests/sqllogictests/test_files/information_schema.slt
b/datafusion/core/tests/sqllogictests/test_files/information_schema.slt
index 3adf5585d7..80187564f9 100644
--- a/datafusion/core/tests/sqllogictests/test_files/information_schema.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/information_schema.slt
@@ -163,6 +163,7 @@ datafusion.optimizer.repartition_sorts true
datafusion.optimizer.repartition_windows true
datafusion.optimizer.skip_failed_rules true
datafusion.optimizer.top_down_join_key_reordering true
+datafusion.sql_parser.dialect generic
datafusion.sql_parser.enable_ident_normalization true
datafusion.sql_parser.parse_float_as_decimal false
diff --git a/datafusion/expr/src/tree_node/expr.rs
b/datafusion/expr/src/tree_node/expr.rs
index 61a5c91fec..b0a5e31da0 100644
--- a/datafusion/expr/src/tree_node/expr.rs
+++ b/datafusion/expr/src/tree_node/expr.rs
@@ -297,7 +297,7 @@ impl TreeNode for Expr {
fun,
transform_vec(args, &mut transform)?,
distinct,
- filter,
+ transform_option_box(filter, &mut transform)?,
)),
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) =>
Expr::GroupingSet(GroupingSet::Rollup(
@@ -318,7 +318,7 @@ impl TreeNode for Expr {
Expr::AggregateUDF { args, fun, filter } => Expr::AggregateUDF {
args: transform_vec(args, &mut transform)?,
fun,
- filter,
+ filter: transform_option_box(filter, &mut transform)?,
},
Expr::InList {
expr,
diff --git a/datafusion/optimizer/src/push_down_projection.rs
b/datafusion/optimizer/src/push_down_projection.rs
index fd8f4c011a..97ba5a92d7 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -1030,7 +1030,7 @@ mod tests {
)?
.build()?;
- let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b),
COUNT(test.b) FILTER (WHERE c > Int32(42)) AS count2]]\
+ let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b),
COUNT(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\
\n TableScan: test projection=[a, b, c]";
assert_optimized_plan_eq(&plan, expected)
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 875a6a15cf..c4b3ac2114 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1258,6 +1258,10 @@ message WindowAggExecNode {
Schema input_schema = 4;
}
+message MaybeFilter {
+ PhysicalExprNode expr = 1;
+}
+
message AggregateExecNode {
repeated PhysicalExprNode group_expr = 1;
repeated PhysicalExprNode aggr_expr = 2;
@@ -1269,6 +1273,7 @@ message AggregateExecNode {
Schema input_schema = 7;
repeated PhysicalExprNode null_expr = 8;
repeated bool groups = 9;
+ repeated MaybeFilter filter_expr = 10;
}
message GlobalLimitExecNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 63a8a2ed00..105591a000 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -33,6 +33,9 @@ impl serde::Serialize for AggregateExecNode {
if !self.groups.is_empty() {
len += 1;
}
+ if !self.filter_expr.is_empty() {
+ len += 1;
+ }
let mut struct_ser =
serializer.serialize_struct("datafusion.AggregateExecNode", len)?;
if !self.group_expr.is_empty() {
struct_ser.serialize_field("groupExpr", &self.group_expr)?;
@@ -63,6 +66,9 @@ impl serde::Serialize for AggregateExecNode {
if !self.groups.is_empty() {
struct_ser.serialize_field("groups", &self.groups)?;
}
+ if !self.filter_expr.is_empty() {
+ struct_ser.serialize_field("filterExpr", &self.filter_expr)?;
+ }
struct_ser.end()
}
}
@@ -88,6 +94,8 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
"null_expr",
"nullExpr",
"groups",
+ "filter_expr",
+ "filterExpr",
];
#[allow(clippy::enum_variant_names)]
@@ -101,6 +109,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
InputSchema,
NullExpr,
Groups,
+ FilterExpr,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -131,6 +140,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
"inputSchema" | "input_schema" =>
Ok(GeneratedField::InputSchema),
"nullExpr" | "null_expr" =>
Ok(GeneratedField::NullExpr),
"groups" => Ok(GeneratedField::Groups),
+ "filterExpr" | "filter_expr" =>
Ok(GeneratedField::FilterExpr),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -159,6 +169,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
let mut input_schema__ = None;
let mut null_expr__ = None;
let mut groups__ = None;
+ let mut filter_expr__ = None;
while let Some(k) = map.next_key()? {
match k {
GeneratedField::GroupExpr => {
@@ -215,6 +226,12 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
}
groups__ = Some(map.next_value()?);
}
+ GeneratedField::FilterExpr => {
+ if filter_expr__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("filterExpr"));
+ }
+ filter_expr__ = Some(map.next_value()?);
+ }
}
}
Ok(AggregateExecNode {
@@ -227,6 +244,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
input_schema: input_schema__,
null_expr: null_expr__.unwrap_or_default(),
groups: groups__.unwrap_or_default(),
+ filter_expr: filter_expr__.unwrap_or_default(),
})
}
}
@@ -11280,6 +11298,97 @@ impl<'de> serde::Deserialize<'de> for Map {
deserializer.deserialize_struct("datafusion.Map", FIELDS,
GeneratedVisitor)
}
}
+impl serde::Serialize for MaybeFilter {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ let mut len = 0;
+ if self.expr.is_some() {
+ len += 1;
+ }
+ let mut struct_ser =
serializer.serialize_struct("datafusion.MaybeFilter", len)?;
+ if let Some(v) = self.expr.as_ref() {
+ struct_ser.serialize_field("expr", v)?;
+ }
+ struct_ser.end()
+ }
+}
+impl<'de> serde::Deserialize<'de> for MaybeFilter {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "expr",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ Expr,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut
std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) ->
std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "expr" => Ok(GeneratedField::Expr),
+ _ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = MaybeFilter;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) ->
std::fmt::Result {
+ formatter.write_str("struct datafusion.MaybeFilter")
+ }
+
+ fn visit_map<V>(self, mut map: V) ->
std::result::Result<MaybeFilter, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut expr__ = None;
+ while let Some(k) = map.next_key()? {
+ match k {
+ GeneratedField::Expr => {
+ if expr__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("expr"));
+ }
+ expr__ = map.next_value()?;
+ }
+ }
+ }
+ Ok(MaybeFilter {
+ expr: expr__,
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion.MaybeFilter", FIELDS,
GeneratedVisitor)
+ }
+}
impl serde::Serialize for NegativeNode {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 7764fe7848..ebdd14b2f3 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1791,6 +1791,12 @@ pub struct WindowAggExecNode {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct MaybeFilter {
+ #[prost(message, optional, tag = "1")]
+ pub expr: ::core::option::Option<PhysicalExprNode>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
pub struct AggregateExecNode {
#[prost(message, repeated, tag = "1")]
pub group_expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
@@ -1811,6 +1817,8 @@ pub struct AggregateExecNode {
pub null_expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
#[prost(bool, repeated, tag = "9")]
pub groups: ::prost::alloc::vec::Vec<bool>,
+ #[prost(message, repeated, tag = "10")]
+ pub filter_expr: ::prost::alloc::vec::Vec<MaybeFilter>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index 8fd57f002b..ff13bbfb8f 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -403,6 +403,18 @@ impl AsExecutionPlan for PhysicalPlanNode {
let physical_schema: SchemaRef =
SchemaRef::new((&input_schema).try_into()?);
+ let physical_filter_expr = hash_agg
+ .filter_expr
+ .iter()
+ .map(|expr| {
+ let x = expr
+ .expr
+ .as_ref()
+ .map(|e| parse_physical_expr(e, registry,
&physical_schema));
+ x.transpose()
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
let physical_aggr_expr: Vec<Arc<dyn AggregateExpr>> = hash_agg
.aggr_expr
.iter()
@@ -450,6 +462,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
agg_mode,
PhysicalGroupBy::new(group_expr, null_expr, groups),
physical_aggr_expr,
+ physical_filter_expr,
input,
Arc::new((&input_schema).try_into()?),
)?))
@@ -864,6 +877,12 @@ impl AsExecutionPlan for PhysicalPlanNode {
.map(|expr| expr.1.to_owned())
.collect();
+ let filter = exec
+ .filter_expr()
+ .iter()
+ .map(|expr| expr.to_owned().try_into())
+ .collect::<Result<Vec<_>>>()?;
+
let agg = exec
.aggr_expr()
.iter()
@@ -911,6 +930,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
group_expr,
group_expr_name: group_names,
aggr_expr: agg,
+ filter_expr: filter,
aggr_expr_name: agg_names,
mode: agg_mode as i32,
input: Some(Box::new(input)),
@@ -1391,6 +1411,7 @@ mod roundtrip_tests {
AggregateMode::Final,
PhysicalGroupBy::new_single(groups.clone()),
aggregates.clone(),
+ vec![None],
Arc::new(EmptyExec::new(false, schema.clone())),
schema,
)?))
@@ -1601,6 +1622,7 @@ mod roundtrip_tests {
AggregateMode::Final,
PhysicalGroupBy::new_single(groups),
aggregates.clone(),
+ vec![None],
Arc::new(EmptyExec::new(false, schema.clone())),
schema,
)?))
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs
b/datafusion/proto/src/physical_plan/to_proto.rs
index 9210a2d7fa..e18932575c 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -498,3 +498,16 @@ impl From<JoinSide> for protobuf::JoinSide {
}
}
}
+
+impl TryFrom<Option<Arc<dyn PhysicalExpr>>> for protobuf::MaybeFilter {
+ type Error = DataFusionError;
+
+ fn try_from(expr: Option<Arc<dyn PhysicalExpr>>) -> Result<Self,
Self::Error> {
+ match expr {
+ None => Ok(protobuf::MaybeFilter { expr: None }),
+ Some(expr) => Ok(protobuf::MaybeFilter {
+ expr: Some(expr.try_into()?),
+ }),
+ }
+ }
+}
diff --git a/datafusion/sql/tests/integration_test.rs
b/datafusion/sql/tests/integration_test.rs
index 3749e65573..64ca85b72d 100644
--- a/datafusion/sql/tests/integration_test.rs
+++ b/datafusion/sql/tests/integration_test.rs
@@ -3061,8 +3061,8 @@ fn hive_aggregate_with_filter() -> Result<()> {
let dialect = &HiveDialect {};
let sql = "SELECT SUM(age) FILTER (WHERE age > 4) FROM person";
let plan = logical_plan_with_dialect(sql, dialect)?;
- let expected = "Projection: SUM(person.age) FILTER (WHERE age > Int64(4))\
- \n Aggregate: groupBy=[[]], aggr=[[SUM(person.age) FILTER (WHERE age
> Int64(4))]]\
+ let expected = "Projection: SUM(person.age) FILTER (WHERE person.age >
Int64(4))\
+ \n Aggregate: groupBy=[[]], aggr=[[SUM(person.age) FILTER (WHERE
person.age > Int64(4))]]\
\n TableScan: person"
.to_string();
assert_eq!(plan.display_indent().to_string(), expected);
diff --git a/docs/source/user-guide/configs.md
b/docs/source/user-guide/configs.md
index 749a0bcb06..dc21c81942 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -73,3 +73,4 @@ Environment variables are read during `SessionConfig`
initialisation so they must
| datafusion.explain.physical_plan_only | false |
When set to true, the explain statement will only print physical plans
[...]
| datafusion.sql_parser.parse_float_as_decimal | false |
When set to true, SQL parser will parse float as decimal type
[...]
| datafusion.sql_parser.enable_ident_normalization | true |
When set to true, SQL parser will normalize ident (convert ident to lowercase
when not quoted)
[...]
+| datafusion.sql_parser.dialect | generic |
Configure the SQL dialect used by DataFusion's parser; supported values
include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL,
ClickHouse, BigQuery, and Ansi.
[...]