This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 02427674ee Add `value_from_stats` to AggregateUDFImpl, remove
special case for min/max/count aggregate statistics (#12296)
02427674ee is described below
commit 02427674eef658b9b0acb7142f18e8c1520bdb17
Author: Edmondo Porcu <[email protected]>
AuthorDate: Mon Sep 30 15:09:40 2024 -0400
Add `value_from_stats` to AggregateUDFImpl, remove special case for
min/max/count aggregate statistics (#12296)
* Removes min/max/count comparison based on name in aggregate statistics
* Abstracting away value from statistics
* Removing imports
* Introduced StatisticsArgs
* Fixed docs
---
datafusion/expr/src/lib.rs | 2 +-
datafusion/expr/src/udaf.rs | 27 ++-
datafusion/functions-aggregate/src/count.rs | 35 +++-
datafusion/functions-aggregate/src/min_max.rs | 77 ++++++++-
.../physical-optimizer/src/aggregate_statistics.rs | 182 ++-------------------
datafusion/physical-plan/src/lib.rs | 1 +
6 files changed, 154 insertions(+), 170 deletions(-)
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 32eac90c3e..7d94a3b93e 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -90,7 +90,7 @@ pub use logical_plan::*;
pub use partition_evaluator::PartitionEvaluator;
pub use sqlparser;
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
-pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
+pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs};
pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udwf::{ReversedUDWF, WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index e3ef672daf..d8592bce60 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -26,7 +26,8 @@ use std::vec;
use arrow::datatypes::{DataType, Field};
-use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
+use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue,
Statistics};
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use crate::expr::AggregateFunction;
use crate::function::{
@@ -94,6 +95,19 @@ impl fmt::Display for AggregateUDF {
}
}
+pub struct StatisticsArgs<'a> {
+ pub statistics: &'a Statistics,
+ pub return_type: &'a DataType,
+ /// Whether the aggregate function is distinct.
+ ///
+ /// ```sql
+ /// SELECT COUNT(DISTINCT column1) FROM t;
+ /// ```
+ pub is_distinct: bool,
+ /// The physical expression of arguments the aggregate function takes.
+ pub exprs: &'a [Arc<dyn PhysicalExpr>],
+}
+
impl AggregateUDF {
/// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
///
@@ -244,6 +258,13 @@ impl AggregateUDF {
self.inner.is_descending()
}
+ pub fn value_from_stats(
+ &self,
+ statistics_args: &StatisticsArgs,
+ ) -> Option<ScalarValue> {
+ self.inner.value_from_stats(statistics_args)
+ }
+
/// See [`AggregateUDFImpl::default_value`] for more details.
pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
self.inner.default_value(data_type)
@@ -556,6 +577,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn is_descending(&self) -> Option<bool> {
None
}
+ // Return the value of the current UDF from the statistics
+ fn value_from_stats(&self, _statistics_args: &StatisticsArgs) ->
Option<ScalarValue> {
+ None
+ }
/// Returns default value of the function given the input is all `null`.
///
diff --git a/datafusion/functions-aggregate/src/count.rs
b/datafusion/functions-aggregate/src/count.rs
index 417e28e72a..cc245b3572 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -16,7 +16,9 @@
// under the License.
use ahash::RandomState;
+use datafusion_common::stats::Precision;
use
datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
+use datafusion_physical_expr::expressions;
use std::collections::HashSet;
use std::ops::BitAnd;
use std::{fmt::Debug, sync::Arc};
@@ -46,7 +48,7 @@ use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator,
AggregateUDFImpl,
EmitTo, GroupsAccumulator, Signature, Volatility,
};
-use datafusion_expr::{Expr, ReversedUDAF, TypeSignature};
+use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
use datafusion_functions_aggregate_common::aggregate::count_distinct::{
BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
PrimitiveDistinctCountAccumulator,
@@ -54,6 +56,7 @@ use
datafusion_functions_aggregate_common::aggregate::count_distinct::{
use
datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
use datafusion_physical_expr_common::binary_map::OutputType;
+use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
make_udaf_expr_and_func!(
Count,
count,
@@ -291,6 +294,36 @@ impl AggregateUDFImpl for Count {
fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(0)))
}
+
+ fn value_from_stats(&self, statistics_args: &StatisticsArgs) ->
Option<ScalarValue> {
+ if statistics_args.is_distinct {
+ return None;
+ }
+ if let Precision::Exact(num_rows) =
statistics_args.statistics.num_rows {
+ if statistics_args.exprs.len() == 1 {
+ // TODO optimize with exprs other than Column
+ if let Some(col_expr) = statistics_args.exprs[0]
+ .as_any()
+ .downcast_ref::<expressions::Column>()
+ {
+ let current_val =
&statistics_args.statistics.column_statistics
+ [col_expr.index()]
+ .null_count;
+ if let &Precision::Exact(val) = current_val {
+ return Some(ScalarValue::Int64(Some((num_rows - val)
as i64)));
+ }
+ } else if let Some(lit_expr) = statistics_args.exprs[0]
+ .as_any()
+ .downcast_ref::<expressions::Literal>()
+ {
+ if lit_expr.value() == &COUNT_STAR_EXPANSION {
+ return Some(ScalarValue::Int64(Some(num_rows as i64)));
+ }
+ }
+ }
+ }
+ None
+ }
}
#[derive(Debug)]
diff --git a/datafusion/functions-aggregate/src/min_max.rs
b/datafusion/functions-aggregate/src/min_max.rs
index 961e863960..1ce1abe09e 100644
--- a/datafusion/functions-aggregate/src/min_max.rs
+++ b/datafusion/functions-aggregate/src/min_max.rs
@@ -15,7 +15,7 @@
// under the License.
//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function
-//! [`Min`] and [`MinAccumulator`] accumulator for the `max` function
+//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
@@ -49,10 +49,12 @@ use arrow::datatypes::{
UInt8Type,
};
use arrow_schema::IntervalUnit;
+use datafusion_common::stats::Precision;
use datafusion_common::{
- downcast_value, exec_err, internal_err, DataFusionError, Result,
+ downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError,
Result,
};
use
datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
+use datafusion_physical_expr::expressions;
use std::fmt::Debug;
use arrow::datatypes::i256;
@@ -63,10 +65,10 @@ use arrow::datatypes::{
};
use datafusion_common::ScalarValue;
-use datafusion_expr::GroupsAccumulator;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature,
Volatility,
};
+use datafusion_expr::{GroupsAccumulator, StatisticsArgs};
use half::f16;
use std::ops::Deref;
@@ -147,6 +149,54 @@ macro_rules! instantiate_min_accumulator {
}};
}
+trait FromColumnStatistics {
+ fn value_from_column_statistics(
+ &self,
+ stats: &ColumnStatistics,
+ ) -> Option<ScalarValue>;
+
+ fn value_from_statistics(
+ &self,
+ statistics_args: &StatisticsArgs,
+ ) -> Option<ScalarValue> {
+ if let Precision::Exact(num_rows) =
&statistics_args.statistics.num_rows {
+ match *num_rows {
+ 0 => return
ScalarValue::try_from(statistics_args.return_type).ok(),
+ value if value > 0 => {
+ let col_stats =
&statistics_args.statistics.column_statistics;
+ if statistics_args.exprs.len() == 1 {
+ // TODO optimize with exprs other than Column
+ if let Some(col_expr) = statistics_args.exprs[0]
+ .as_any()
+ .downcast_ref::<expressions::Column>()
+ {
+ return self.value_from_column_statistics(
+ &col_stats[col_expr.index()],
+ );
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ None
+ }
+}
+
+impl FromColumnStatistics for Max {
+ fn value_from_column_statistics(
+ &self,
+ col_stats: &ColumnStatistics,
+ ) -> Option<ScalarValue> {
+ if let Precision::Exact(ref val) = col_stats.max_value {
+ if !val.is_null() {
+ return Some(val.clone());
+ }
+ }
+ None
+ }
+}
+
impl AggregateUDFImpl for Max {
fn as_any(&self) -> &dyn std::any::Any {
self
@@ -272,6 +322,7 @@ impl AggregateUDFImpl for Max {
fn is_descending(&self) -> Option<bool> {
Some(true)
}
+
fn order_sensitivity(&self) ->
datafusion_expr::utils::AggregateOrderSensitivity {
datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
}
@@ -282,6 +333,9 @@ impl AggregateUDFImpl for Max {
fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
datafusion_expr::ReversedUDAF::Identical
}
+ fn value_from_stats(&self, statistics_args: &StatisticsArgs) ->
Option<ScalarValue> {
+ self.value_from_statistics(statistics_args)
+ }
}
// Statically-typed version of min/max(array) -> ScalarValue for string types
@@ -926,6 +980,20 @@ impl Default for Min {
}
}
+impl FromColumnStatistics for Min {
+ fn value_from_column_statistics(
+ &self,
+ col_stats: &ColumnStatistics,
+ ) -> Option<ScalarValue> {
+ if let Precision::Exact(ref val) = col_stats.min_value {
+ if !val.is_null() {
+ return Some(val.clone());
+ }
+ }
+ None
+ }
+}
+
impl AggregateUDFImpl for Min {
fn as_any(&self) -> &dyn std::any::Any {
self
@@ -1052,6 +1120,9 @@ impl AggregateUDFImpl for Min {
Some(false)
}
+ fn value_from_stats(&self, statistics_args: &StatisticsArgs) ->
Option<ScalarValue> {
+ self.value_from_statistics(statistics_args)
+ }
fn order_sensitivity(&self) ->
datafusion_expr::utils::AggregateOrderSensitivity {
datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
}
diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs
b/datafusion/physical-optimizer/src/aggregate_statistics.rs
index 71f129be98..a11b498b95 100644
--- a/datafusion/physical-optimizer/src/aggregate_statistics.rs
+++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs
@@ -23,14 +23,12 @@ use datafusion_common::scalar::ScalarValue;
use datafusion_common::Result;
use datafusion_physical_plan::aggregates::AggregateExec;
use datafusion_physical_plan::projection::ProjectionExec;
-use datafusion_physical_plan::{expressions, ExecutionPlan, Statistics};
+use datafusion_physical_plan::{expressions, ExecutionPlan};
use crate::PhysicalOptimizerRule;
-use datafusion_common::stats::Precision;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
-use datafusion_physical_plan::udaf::AggregateFunctionExpr;
+use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs};
/// Optimizer that uses available statistics for aggregate functions
#[derive(Default, Debug)]
@@ -57,14 +55,19 @@ impl PhysicalOptimizerRule for AggregateStatistics {
let stats = partial_agg_exec.input().statistics()?;
let mut projections = vec![];
for expr in partial_agg_exec.aggr_expr() {
- if let Some((non_null_rows, name)) =
- take_optimizable_column_and_table_count(expr, &stats)
+ let field = expr.field();
+ let args = expr.expressions();
+ let statistics_args = StatisticsArgs {
+ statistics: &stats,
+ return_type: field.data_type(),
+ is_distinct: expr.is_distinct(),
+ exprs: args.as_slice(),
+ };
+ if let Some((optimizable_statistic, name)) =
+ take_optimizable_value_from_statistics(&statistics_args,
expr)
{
- projections.push((expressions::lit(non_null_rows),
name.to_owned()));
- } else if let Some((min, name)) = take_optimizable_min(expr,
&stats) {
- projections.push((expressions::lit(min), name.to_owned()));
- } else if let Some((max, name)) = take_optimizable_max(expr,
&stats) {
- projections.push((expressions::lit(max), name.to_owned()));
+ projections
+ .push((expressions::lit(optimizable_statistic),
name.to_owned()));
} else {
// TODO: we need all aggr_expr to be resolved (cf TODO
fullres)
break;
@@ -135,160 +138,11 @@ fn take_optimizable(node: &dyn ExecutionPlan) ->
Option<Arc<dyn ExecutionPlan>>
None
}
-/// If this agg_expr is a count that can be exactly derived from the
statistics, return it.
-fn take_optimizable_column_and_table_count(
- agg_expr: &AggregateFunctionExpr,
- stats: &Statistics,
-) -> Option<(ScalarValue, String)> {
- let col_stats = &stats.column_statistics;
- if is_non_distinct_count(agg_expr) {
- if let Precision::Exact(num_rows) = stats.num_rows {
- let exprs = agg_expr.expressions();
- if exprs.len() == 1 {
- // TODO optimize with exprs other than Column
- if let Some(col_expr) =
- exprs[0].as_any().downcast_ref::<expressions::Column>()
- {
- let current_val = &col_stats[col_expr.index()].null_count;
- if let &Precision::Exact(val) = current_val {
- return Some((
- ScalarValue::Int64(Some((num_rows - val) as i64)),
- agg_expr.name().to_string(),
- ));
- }
- } else if let Some(lit_expr) =
- exprs[0].as_any().downcast_ref::<expressions::Literal>()
- {
- if lit_expr.value() == &COUNT_STAR_EXPANSION {
- return Some((
- ScalarValue::Int64(Some(num_rows as i64)),
- agg_expr.name().to_string(),
- ));
- }
- }
- }
- }
- }
- None
-}
-
-/// If this agg_expr is a min that is exactly defined in the statistics,
return it.
-fn take_optimizable_min(
- agg_expr: &AggregateFunctionExpr,
- stats: &Statistics,
-) -> Option<(ScalarValue, String)> {
- if let Precision::Exact(num_rows) = &stats.num_rows {
- match *num_rows {
- 0 => {
- // MIN/MAX with 0 rows is always null
- if is_min(agg_expr) {
- if let Ok(min_data_type) =
- ScalarValue::try_from(agg_expr.field().data_type())
- {
- return Some((min_data_type,
agg_expr.name().to_string()));
- }
- }
- }
- value if value > 0 => {
- let col_stats = &stats.column_statistics;
- if is_min(agg_expr) {
- let exprs = agg_expr.expressions();
- if exprs.len() == 1 {
- // TODO optimize with exprs other than Column
- if let Some(col_expr) =
-
exprs[0].as_any().downcast_ref::<expressions::Column>()
- {
- if let Precision::Exact(val) =
- &col_stats[col_expr.index()].min_value
- {
- if !val.is_null() {
- return Some((
- val.clone(),
- agg_expr.name().to_string(),
- ));
- }
- }
- }
- }
- }
- }
- _ => {}
- }
- }
- None
-}
-
/// If this agg_expr is a max that is exactly defined in the statistics,
return it.
-fn take_optimizable_max(
+fn take_optimizable_value_from_statistics(
+ statistics_args: &StatisticsArgs,
agg_expr: &AggregateFunctionExpr,
- stats: &Statistics,
) -> Option<(ScalarValue, String)> {
- if let Precision::Exact(num_rows) = &stats.num_rows {
- match *num_rows {
- 0 => {
- // MIN/MAX with 0 rows is always null
- if is_max(agg_expr) {
- if let Ok(max_data_type) =
- ScalarValue::try_from(agg_expr.field().data_type())
- {
- return Some((max_data_type,
agg_expr.name().to_string()));
- }
- }
- }
- value if value > 0 => {
- let col_stats = &stats.column_statistics;
- if is_max(agg_expr) {
- let exprs = agg_expr.expressions();
- if exprs.len() == 1 {
- // TODO optimize with exprs other than Column
- if let Some(col_expr) =
-
exprs[0].as_any().downcast_ref::<expressions::Column>()
- {
- if let Precision::Exact(val) =
- &col_stats[col_expr.index()].max_value
- {
- if !val.is_null() {
- return Some((
- val.clone(),
- agg_expr.name().to_string(),
- ));
- }
- }
- }
- }
- }
- }
- _ => {}
- }
- }
- None
-}
-
-// TODO: Move this check into AggregateUDFImpl
-// https://github.com/apache/datafusion/issues/11153
-fn is_non_distinct_count(agg_expr: &AggregateFunctionExpr) -> bool {
- if agg_expr.fun().name() == "count" && !agg_expr.is_distinct() {
- return true;
- }
- false
+ let value = agg_expr.fun().value_from_stats(statistics_args);
+ value.map(|val| (val, agg_expr.name().to_string()))
}
-
-// TODO: Move this check into AggregateUDFImpl
-// https://github.com/apache/datafusion/issues/11153
-fn is_min(agg_expr: &AggregateFunctionExpr) -> bool {
- if agg_expr.fun().name().to_lowercase() == "min" {
- return true;
- }
- false
-}
-
-// TODO: Move this check into AggregateUDFImpl
-// https://github.com/apache/datafusion/issues/11153
-fn is_max(agg_expr: &AggregateFunctionExpr) -> bool {
- if agg_expr.fun().name().to_lowercase() == "max" {
- return true;
- }
- false
-}
-
-// See tests in datafusion/core/tests/physical_optimizer
diff --git a/datafusion/physical-plan/src/lib.rs
b/datafusion/physical-plan/src/lib.rs
index 7cbfd49afb..845a74eaea 100644
--- a/datafusion/physical-plan/src/lib.rs
+++ b/datafusion/physical-plan/src/lib.rs
@@ -82,6 +82,7 @@ pub mod windows;
pub mod work_table;
pub mod udaf {
+ pub use datafusion_expr::StatisticsArgs;
pub use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]