This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 02427674ee Add `value_from_statistics` to AggregateUDFImpl, remove 
special case for min/max/count aggregate statistics (#12296)
02427674ee is described below

commit 02427674eef658b9b0acb7142f18e8c1520bdb17
Author: Edmondo Porcu <[email protected]>
AuthorDate: Mon Sep 30 15:09:40 2024 -0400

    Add `value_from_statistics` to AggregateUDFImpl, remove special case for 
min/max/count aggregate statistics (#12296)
    
    * Removes min/max/count comparison based on name in aggregate statistics
    
    * Abstracting away value from statistics
    
    * Removing imports
    
    * Introduced StatisticsArgs
    
    * Fixed docs
---
 datafusion/expr/src/lib.rs                         |   2 +-
 datafusion/expr/src/udaf.rs                        |  27 ++-
 datafusion/functions-aggregate/src/count.rs        |  35 +++-
 datafusion/functions-aggregate/src/min_max.rs      |  77 ++++++++-
 .../physical-optimizer/src/aggregate_statistics.rs | 182 ++-------------------
 datafusion/physical-plan/src/lib.rs                |   1 +
 6 files changed, 154 insertions(+), 170 deletions(-)

diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 32eac90c3e..7d94a3b93e 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -90,7 +90,7 @@ pub use logical_plan::*;
 pub use partition_evaluator::PartitionEvaluator;
 pub use sqlparser;
 pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
-pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
+pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs};
 pub use udf::{ScalarUDF, ScalarUDFImpl};
 pub use udwf::{ReversedUDWF, WindowUDF, WindowUDFImpl};
 pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index e3ef672daf..d8592bce60 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -26,7 +26,8 @@ use std::vec;
 
 use arrow::datatypes::{DataType, Field};
 
-use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
+use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, 
Statistics};
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 
 use crate::expr::AggregateFunction;
 use crate::function::{
@@ -94,6 +95,19 @@ impl fmt::Display for AggregateUDF {
     }
 }
 
+pub struct StatisticsArgs<'a> {
+    pub statistics: &'a Statistics,
+    pub return_type: &'a DataType,
+    /// Whether the aggregate function is distinct.
+    ///
+    /// ```sql
+    /// SELECT COUNT(DISTINCT column1) FROM t;
+    /// ```
+    pub is_distinct: bool,
+    /// The physical expression of arguments the aggregate function takes.
+    pub exprs: &'a [Arc<dyn PhysicalExpr>],
+}
+
 impl AggregateUDF {
     /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
     ///
@@ -244,6 +258,13 @@ impl AggregateUDF {
         self.inner.is_descending()
     }
 
+    pub fn value_from_stats(
+        &self,
+        statistics_args: &StatisticsArgs,
+    ) -> Option<ScalarValue> {
+        self.inner.value_from_stats(statistics_args)
+    }
+
     /// See [`AggregateUDFImpl::default_value`] for more details.
     pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
         self.inner.default_value(data_type)
@@ -556,6 +577,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
     fn is_descending(&self) -> Option<bool> {
         None
     }
+    // Return the value of the current UDF from the statistics
+    fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> 
Option<ScalarValue> {
+        None
+    }
 
     /// Returns default value of the function given the input is all `null`.
     ///
diff --git a/datafusion/functions-aggregate/src/count.rs 
b/datafusion/functions-aggregate/src/count.rs
index 417e28e72a..cc245b3572 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -16,7 +16,9 @@
 // under the License.
 
 use ahash::RandomState;
+use datafusion_common::stats::Precision;
 use 
datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
+use datafusion_physical_expr::expressions;
 use std::collections::HashSet;
 use std::ops::BitAnd;
 use std::{fmt::Debug, sync::Arc};
@@ -46,7 +48,7 @@ use datafusion_expr::{
     function::AccumulatorArgs, utils::format_state_name, Accumulator, 
AggregateUDFImpl,
     EmitTo, GroupsAccumulator, Signature, Volatility,
 };
-use datafusion_expr::{Expr, ReversedUDAF, TypeSignature};
+use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
 use datafusion_functions_aggregate_common::aggregate::count_distinct::{
     BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
     PrimitiveDistinctCountAccumulator,
@@ -54,6 +56,7 @@ use 
datafusion_functions_aggregate_common::aggregate::count_distinct::{
 use 
datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
 use datafusion_physical_expr_common::binary_map::OutputType;
 
+use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
 make_udaf_expr_and_func!(
     Count,
     count,
@@ -291,6 +294,36 @@ impl AggregateUDFImpl for Count {
     fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
         Ok(ScalarValue::Int64(Some(0)))
     }
+
+    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> 
Option<ScalarValue> {
+        if statistics_args.is_distinct {
+            return None;
+        }
+        if let Precision::Exact(num_rows) = 
statistics_args.statistics.num_rows {
+            if statistics_args.exprs.len() == 1 {
+                // TODO optimize with exprs other than Column
+                if let Some(col_expr) = statistics_args.exprs[0]
+                    .as_any()
+                    .downcast_ref::<expressions::Column>()
+                {
+                    let current_val = 
&statistics_args.statistics.column_statistics
+                        [col_expr.index()]
+                    .null_count;
+                    if let &Precision::Exact(val) = current_val {
+                        return Some(ScalarValue::Int64(Some((num_rows - val) 
as i64)));
+                    }
+                } else if let Some(lit_expr) = statistics_args.exprs[0]
+                    .as_any()
+                    .downcast_ref::<expressions::Literal>()
+                {
+                    if lit_expr.value() == &COUNT_STAR_EXPANSION {
+                        return Some(ScalarValue::Int64(Some(num_rows as i64)));
+                    }
+                }
+            }
+        }
+        None
+    }
 }
 
 #[derive(Debug)]
diff --git a/datafusion/functions-aggregate/src/min_max.rs 
b/datafusion/functions-aggregate/src/min_max.rs
index 961e863960..1ce1abe09e 100644
--- a/datafusion/functions-aggregate/src/min_max.rs
+++ b/datafusion/functions-aggregate/src/min_max.rs
@@ -15,7 +15,7 @@
 // under the License.
 
 //! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function
-//! [`Min`] and [`MinAccumulator`] accumulator for the `max` function
+//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function
 
 // distributed with this work for additional information
 // regarding copyright ownership.  The ASF licenses this file
@@ -49,10 +49,12 @@ use arrow::datatypes::{
     UInt8Type,
 };
 use arrow_schema::IntervalUnit;
+use datafusion_common::stats::Precision;
 use datafusion_common::{
-    downcast_value, exec_err, internal_err, DataFusionError, Result,
+    downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, 
Result,
 };
 use 
datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
+use datafusion_physical_expr::expressions;
 use std::fmt::Debug;
 
 use arrow::datatypes::i256;
@@ -63,10 +65,10 @@ use arrow::datatypes::{
 };
 
 use datafusion_common::ScalarValue;
-use datafusion_expr::GroupsAccumulator;
 use datafusion_expr::{
     function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, 
Volatility,
 };
+use datafusion_expr::{GroupsAccumulator, StatisticsArgs};
 use half::f16;
 use std::ops::Deref;
 
@@ -147,6 +149,54 @@ macro_rules! instantiate_min_accumulator {
     }};
 }
 
+trait FromColumnStatistics {
+    fn value_from_column_statistics(
+        &self,
+        stats: &ColumnStatistics,
+    ) -> Option<ScalarValue>;
+
+    fn value_from_statistics(
+        &self,
+        statistics_args: &StatisticsArgs,
+    ) -> Option<ScalarValue> {
+        if let Precision::Exact(num_rows) = 
&statistics_args.statistics.num_rows {
+            match *num_rows {
+                0 => return 
ScalarValue::try_from(statistics_args.return_type).ok(),
+                value if value > 0 => {
+                    let col_stats = 
&statistics_args.statistics.column_statistics;
+                    if statistics_args.exprs.len() == 1 {
+                        // TODO optimize with exprs other than Column
+                        if let Some(col_expr) = statistics_args.exprs[0]
+                            .as_any()
+                            .downcast_ref::<expressions::Column>()
+                        {
+                            return self.value_from_column_statistics(
+                                &col_stats[col_expr.index()],
+                            );
+                        }
+                    }
+                }
+                _ => {}
+            }
+        }
+        None
+    }
+}
+
+impl FromColumnStatistics for Max {
+    fn value_from_column_statistics(
+        &self,
+        col_stats: &ColumnStatistics,
+    ) -> Option<ScalarValue> {
+        if let Precision::Exact(ref val) = col_stats.max_value {
+            if !val.is_null() {
+                return Some(val.clone());
+            }
+        }
+        None
+    }
+}
+
 impl AggregateUDFImpl for Max {
     fn as_any(&self) -> &dyn std::any::Any {
         self
@@ -272,6 +322,7 @@ impl AggregateUDFImpl for Max {
     fn is_descending(&self) -> Option<bool> {
         Some(true)
     }
+
     fn order_sensitivity(&self) -> 
datafusion_expr::utils::AggregateOrderSensitivity {
         datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
     }
@@ -282,6 +333,9 @@ impl AggregateUDFImpl for Max {
     fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
         datafusion_expr::ReversedUDAF::Identical
     }
+    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> 
Option<ScalarValue> {
+        self.value_from_statistics(statistics_args)
+    }
 }
 
 // Statically-typed version of min/max(array) -> ScalarValue for string types
@@ -926,6 +980,20 @@ impl Default for Min {
     }
 }
 
+impl FromColumnStatistics for Min {
+    fn value_from_column_statistics(
+        &self,
+        col_stats: &ColumnStatistics,
+    ) -> Option<ScalarValue> {
+        if let Precision::Exact(ref val) = col_stats.min_value {
+            if !val.is_null() {
+                return Some(val.clone());
+            }
+        }
+        None
+    }
+}
+
 impl AggregateUDFImpl for Min {
     fn as_any(&self) -> &dyn std::any::Any {
         self
@@ -1052,6 +1120,9 @@ impl AggregateUDFImpl for Min {
         Some(false)
     }
 
+    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> 
Option<ScalarValue> {
+        self.value_from_statistics(statistics_args)
+    }
     fn order_sensitivity(&self) -> 
datafusion_expr::utils::AggregateOrderSensitivity {
         datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
     }
diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs 
b/datafusion/physical-optimizer/src/aggregate_statistics.rs
index 71f129be98..a11b498b95 100644
--- a/datafusion/physical-optimizer/src/aggregate_statistics.rs
+++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs
@@ -23,14 +23,12 @@ use datafusion_common::scalar::ScalarValue;
 use datafusion_common::Result;
 use datafusion_physical_plan::aggregates::AggregateExec;
 use datafusion_physical_plan::projection::ProjectionExec;
-use datafusion_physical_plan::{expressions, ExecutionPlan, Statistics};
+use datafusion_physical_plan::{expressions, ExecutionPlan};
 
 use crate::PhysicalOptimizerRule;
-use datafusion_common::stats::Precision;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
 use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
-use datafusion_physical_plan::udaf::AggregateFunctionExpr;
+use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs};
 
 /// Optimizer that uses available statistics for aggregate functions
 #[derive(Default, Debug)]
@@ -57,14 +55,19 @@ impl PhysicalOptimizerRule for AggregateStatistics {
             let stats = partial_agg_exec.input().statistics()?;
             let mut projections = vec![];
             for expr in partial_agg_exec.aggr_expr() {
-                if let Some((non_null_rows, name)) =
-                    take_optimizable_column_and_table_count(expr, &stats)
+                let field = expr.field();
+                let args = expr.expressions();
+                let statistics_args = StatisticsArgs {
+                    statistics: &stats,
+                    return_type: field.data_type(),
+                    is_distinct: expr.is_distinct(),
+                    exprs: args.as_slice(),
+                };
+                if let Some((optimizable_statistic, name)) =
+                    take_optimizable_value_from_statistics(&statistics_args, 
expr)
                 {
-                    projections.push((expressions::lit(non_null_rows), 
name.to_owned()));
-                } else if let Some((min, name)) = take_optimizable_min(expr, 
&stats) {
-                    projections.push((expressions::lit(min), name.to_owned()));
-                } else if let Some((max, name)) = take_optimizable_max(expr, 
&stats) {
-                    projections.push((expressions::lit(max), name.to_owned()));
+                    projections
+                        .push((expressions::lit(optimizable_statistic), 
name.to_owned()));
                 } else {
                     // TODO: we need all aggr_expr to be resolved (cf TODO 
fullres)
                     break;
@@ -135,160 +138,11 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> 
Option<Arc<dyn ExecutionPlan>>
     None
 }
 
-/// If this agg_expr is a count that can be exactly derived from the 
statistics, return it.
-fn take_optimizable_column_and_table_count(
-    agg_expr: &AggregateFunctionExpr,
-    stats: &Statistics,
-) -> Option<(ScalarValue, String)> {
-    let col_stats = &stats.column_statistics;
-    if is_non_distinct_count(agg_expr) {
-        if let Precision::Exact(num_rows) = stats.num_rows {
-            let exprs = agg_expr.expressions();
-            if exprs.len() == 1 {
-                // TODO optimize with exprs other than Column
-                if let Some(col_expr) =
-                    exprs[0].as_any().downcast_ref::<expressions::Column>()
-                {
-                    let current_val = &col_stats[col_expr.index()].null_count;
-                    if let &Precision::Exact(val) = current_val {
-                        return Some((
-                            ScalarValue::Int64(Some((num_rows - val) as i64)),
-                            agg_expr.name().to_string(),
-                        ));
-                    }
-                } else if let Some(lit_expr) =
-                    exprs[0].as_any().downcast_ref::<expressions::Literal>()
-                {
-                    if lit_expr.value() == &COUNT_STAR_EXPANSION {
-                        return Some((
-                            ScalarValue::Int64(Some(num_rows as i64)),
-                            agg_expr.name().to_string(),
-                        ));
-                    }
-                }
-            }
-        }
-    }
-    None
-}
-
-/// If this agg_expr is a min that is exactly defined in the statistics, 
return it.
-fn take_optimizable_min(
-    agg_expr: &AggregateFunctionExpr,
-    stats: &Statistics,
-) -> Option<(ScalarValue, String)> {
-    if let Precision::Exact(num_rows) = &stats.num_rows {
-        match *num_rows {
-            0 => {
-                // MIN/MAX with 0 rows is always null
-                if is_min(agg_expr) {
-                    if let Ok(min_data_type) =
-                        ScalarValue::try_from(agg_expr.field().data_type())
-                    {
-                        return Some((min_data_type, 
agg_expr.name().to_string()));
-                    }
-                }
-            }
-            value if value > 0 => {
-                let col_stats = &stats.column_statistics;
-                if is_min(agg_expr) {
-                    let exprs = agg_expr.expressions();
-                    if exprs.len() == 1 {
-                        // TODO optimize with exprs other than Column
-                        if let Some(col_expr) =
-                            
exprs[0].as_any().downcast_ref::<expressions::Column>()
-                        {
-                            if let Precision::Exact(val) =
-                                &col_stats[col_expr.index()].min_value
-                            {
-                                if !val.is_null() {
-                                    return Some((
-                                        val.clone(),
-                                        agg_expr.name().to_string(),
-                                    ));
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-            _ => {}
-        }
-    }
-    None
-}
-
 /// If this agg_expr is a max that is exactly defined in the statistics, 
return it.
-fn take_optimizable_max(
+fn take_optimizable_value_from_statistics(
+    statistics_args: &StatisticsArgs,
     agg_expr: &AggregateFunctionExpr,
-    stats: &Statistics,
 ) -> Option<(ScalarValue, String)> {
-    if let Precision::Exact(num_rows) = &stats.num_rows {
-        match *num_rows {
-            0 => {
-                // MIN/MAX with 0 rows is always null
-                if is_max(agg_expr) {
-                    if let Ok(max_data_type) =
-                        ScalarValue::try_from(agg_expr.field().data_type())
-                    {
-                        return Some((max_data_type, 
agg_expr.name().to_string()));
-                    }
-                }
-            }
-            value if value > 0 => {
-                let col_stats = &stats.column_statistics;
-                if is_max(agg_expr) {
-                    let exprs = agg_expr.expressions();
-                    if exprs.len() == 1 {
-                        // TODO optimize with exprs other than Column
-                        if let Some(col_expr) =
-                            
exprs[0].as_any().downcast_ref::<expressions::Column>()
-                        {
-                            if let Precision::Exact(val) =
-                                &col_stats[col_expr.index()].max_value
-                            {
-                                if !val.is_null() {
-                                    return Some((
-                                        val.clone(),
-                                        agg_expr.name().to_string(),
-                                    ));
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-            _ => {}
-        }
-    }
-    None
-}
-
-// TODO: Move this check into AggregateUDFImpl
-// https://github.com/apache/datafusion/issues/11153
-fn is_non_distinct_count(agg_expr: &AggregateFunctionExpr) -> bool {
-    if agg_expr.fun().name() == "count" && !agg_expr.is_distinct() {
-        return true;
-    }
-    false
+    let value = agg_expr.fun().value_from_stats(statistics_args);
+    value.map(|val| (val, agg_expr.name().to_string()))
 }
-
-// TODO: Move this check into AggregateUDFImpl
-// https://github.com/apache/datafusion/issues/11153
-fn is_min(agg_expr: &AggregateFunctionExpr) -> bool {
-    if agg_expr.fun().name().to_lowercase() == "min" {
-        return true;
-    }
-    false
-}
-
-// TODO: Move this check into AggregateUDFImpl
-// https://github.com/apache/datafusion/issues/11153
-fn is_max(agg_expr: &AggregateFunctionExpr) -> bool {
-    if agg_expr.fun().name().to_lowercase() == "max" {
-        return true;
-    }
-    false
-}
-
-// See tests in datafusion/core/tests/physical_optimizer
diff --git a/datafusion/physical-plan/src/lib.rs 
b/datafusion/physical-plan/src/lib.rs
index 7cbfd49afb..845a74eaea 100644
--- a/datafusion/physical-plan/src/lib.rs
+++ b/datafusion/physical-plan/src/lib.rs
@@ -82,6 +82,7 @@ pub mod windows;
 pub mod work_table;
 
 pub mod udaf {
+    pub use datafusion_expr::StatisticsArgs;
     pub use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to