This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 929eb6d86 Support number of centroids in approx_percentile_cont (#3146)
929eb6d86 is described below

commit 929eb6d860fb60ba994b24397ad3c3eb7d839cdf
Author: Yang Jiang <[email protected]>
AuthorDate: Wed Aug 17 17:15:37 2022 +0800

    Support number of centroids in approx_percentile_cont (#3146)
    
    * Support number of histogram bins in approx_percentile_cont
    
    * add args check and UT
    
    * add doc
---
 datafusion/core/tests/sql/aggregates.rs            |  48 ++++++++
 datafusion/expr/src/aggregate_function.rs          |  40 +++++--
 .../src/aggregate/approx_percentile_cont.rs        | 126 ++++++++++++++++-----
 datafusion/physical-expr/src/aggregate/build_in.rs |  21 +++-
 docs/source/user-guide/sql/aggregate_functions.md  |   6 +
 5 files changed, 201 insertions(+), 40 deletions(-)

diff --git a/datafusion/core/tests/sql/aggregates.rs 
b/datafusion/core/tests/sql/aggregates.rs
index 7d02d8cb5..9be41afa2 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -925,6 +925,54 @@ async fn csv_query_approx_percentile_cont_with_weight() -> 
Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_query_approx_percentile_cont_with_histogram_bins() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+
+    // compare approx_percentile_cont and approx_percentile_cont_with_weight
+    let sql = "SELECT c1, approx_percentile_cont(c3, 0.95, 200) AS c3_p95 FROM 
aggregate_test_100 GROUP BY 1 ORDER BY 1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | c3_p95 |",
+        "+----+--------+",
+        "| a  | 73     |",
+        "| b  | 68     |",
+        "| c  | 122    |",
+        "| d  | 124    |",
+        "| e  | 115    |",
+        "+----+--------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let results = plan_and_collect(
+        &ctx,
+        "SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM 
aggregate_test_100 GROUP BY 1 ORDER BY 1",
+    )
+        .await
+        .unwrap_err();
+    assert_eq!(results.to_string(), "This feature is not implemented: Tdigest 
max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data 
type Int64).");
+
+    let results = plan_and_collect(
+        &ctx,
+        "SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100",
+    )
+    .await
+    .unwrap_err();
+    assert_eq!(results.to_string(), "Error during planning: The percentile 
sample points count for ApproxPercentileCont must be integer, not Utf8.");
+
+    let results = plan_and_collect(
+        &ctx,
+        "SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM 
aggregate_test_100",
+    )
+    .await
+    .unwrap_err();
+    assert_eq!(results.to_string(), "Error during planning: The percentile 
sample points count for ApproxPercentileCont must be integer, not Float64.");
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn csv_query_sum_crossjoin() {
     let ctx = SessionContext::new();
diff --git a/datafusion/expr/src/aggregate_function.rs 
b/datafusion/expr/src/aggregate_function.rs
index 71c598e42..7b8616921 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -310,6 +310,12 @@ pub fn coerce_types(
                     agg_fun, input_types[1]
                 )));
             }
+            if input_types.len() == 3 && !is_integer_arg_type(&input_types[2]) 
{
+                return Err(DataFusionError::Plan(format!(
+                        "The percentile sample points count for {:?} must be 
integer, not {:?}.",
+                        agg_fun, input_types[2]
+                    )));
+            }
             Ok(input_types.to_vec())
         }
         AggregateFunction::ApproxPercentileContWithWeight => {
@@ -382,14 +388,20 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
         AggregateFunction::Correlation => {
             Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
         }
-        AggregateFunction::ApproxPercentileCont => Signature::one_of(
+        AggregateFunction::ApproxPercentileCont => {
             // Accept any numeric value paired with a float64 percentile
-            NUMERICS
-                .iter()
-                .map(|t| TypeSignature::Exact(vec![t.clone(), 
DataType::Float64]))
-                .collect(),
-            Volatility::Immutable,
-        ),
+            let with_tdigest_size = NUMERICS.iter().map(|t| {
+                TypeSignature::Exact(vec![t.clone(), DataType::Float64, 
t.clone()])
+            });
+            Signature::one_of(
+                NUMERICS
+                    .iter()
+                    .map(|t| TypeSignature::Exact(vec![t.clone(), 
DataType::Float64]))
+                    .chain(with_tdigest_size)
+                    .collect(),
+                Volatility::Immutable,
+            )
+        }
         AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of(
             // Accept any numeric value paired with a float64 percentile
             NUMERICS
@@ -702,6 +714,20 @@ pub fn is_correlation_support_arg_type(arg_type: 
&DataType) -> bool {
     )
 }
 
+pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
+    matches!(
+        arg_type,
+        DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+    )
+}
+
 /// Return `true` if `arg_type` is of a [`DataType`] that the
 /// [`AggregateFunction::ApproxPercentileCont`] aggregation can operate on.
 pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> 
bool {
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs 
b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
index 41c6c72db..ee32b0a6a 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
@@ -40,6 +40,7 @@ pub struct ApproxPercentileCont {
     input_data_type: DataType,
     expr: Vec<Arc<dyn PhysicalExpr>>,
     percentile: f64,
+    tdigest_max_size: Option<usize>,
 }
 
 impl ApproxPercentileCont {
@@ -52,39 +53,35 @@ impl ApproxPercentileCont {
         // Arguments should be [ColumnExpr, DesiredPercentileLiteral]
         debug_assert_eq!(expr.len(), 2);
 
-        // Extract the desired percentile literal
-        let lit = expr[1]
-            .as_any()
-            .downcast_ref::<Literal>()
-            .ok_or_else(|| {
-                DataFusionError::Internal(
-                    "desired percentile argument must be float 
literal".to_string(),
-                )
-            })?
-            .value();
-        let percentile = match lit {
-            ScalarValue::Float32(Some(q)) => *q as f64,
-            ScalarValue::Float64(Some(q)) => *q as f64,
-            got => return Err(DataFusionError::NotImplemented(format!(
-                "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 
or Float64 literal (got data type {})",
-                got
-            )))
-        };
+        let percentile = validate_input_percentile_expr(&expr[1])?;
 
-        // Ensure the percentile is between 0 and 1.
-        if !(0.0..=1.0).contains(&percentile) {
-            return Err(DataFusionError::Plan(format!(
-                "Percentile value must be between 0.0 and 1.0 inclusive, {} is 
invalid",
-                percentile
-            )));
-        }
+        Ok(Self {
+            name: name.into(),
+            input_data_type,
+            // The physical expr to evaluate during accumulation
+            expr,
+            percentile,
+            tdigest_max_size: None,
+        })
+    }
 
+    /// Create a new [`ApproxPercentileCont`] aggregate function.
+    pub fn new_with_max_size(
+        expr: Vec<Arc<dyn PhysicalExpr>>,
+        name: impl Into<String>,
+        input_data_type: DataType,
+    ) -> Result<Self> {
+        // Arguments should be [ColumnExpr, DesiredPercentileLiteral, 
TDigestMaxSize]
+        debug_assert_eq!(expr.len(), 3);
+        let percentile = validate_input_percentile_expr(&expr[1])?;
+        let max_size = validate_input_max_size_expr(&expr[2])?;
         Ok(Self {
             name: name.into(),
             input_data_type,
             // The physical expr to evaluate during accumulation
             expr,
             percentile,
+            tdigest_max_size: Some(max_size),
         })
     }
 
@@ -100,7 +97,13 @@ impl ApproxPercentileCont {
             | DataType::Int64
             | DataType::Float32
             | DataType::Float64) => {
-                ApproxPercentileAccumulator::new(self.percentile, t.clone())
+                if let Some(max_size) = self.tdigest_max_size {
+                    
ApproxPercentileAccumulator::new_with_max_size(self.percentile, t.clone(), 
max_size)
+
+                }else{
+                    ApproxPercentileAccumulator::new(self.percentile, 
t.clone())
+
+                }
             }
             other => {
                 return Err(DataFusionError::NotImplemented(format!(
@@ -113,6 +116,64 @@ impl ApproxPercentileCont {
     }
 }
 
+fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> 
{
+    // Extract the desired percentile literal
+    let lit = expr
+        .as_any()
+        .downcast_ref::<Literal>()
+        .ok_or_else(|| {
+            DataFusionError::Internal(
+                "desired percentile argument must be float 
literal".to_string(),
+            )
+        })?
+        .value();
+    let percentile = match lit {
+        ScalarValue::Float32(Some(q)) => *q as f64,
+        ScalarValue::Float64(Some(q)) => *q as f64,
+        got => return Err(DataFusionError::NotImplemented(format!(
+            "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or 
Float64 literal (got data type {})",
+            got.get_datatype()
+        )))
+    };
+
+    // Ensure the percentile is between 0 and 1.
+    if !(0.0..=1.0).contains(&percentile) {
+        return Err(DataFusionError::Plan(format!(
+            "Percentile value must be between 0.0 and 1.0 inclusive, {} is 
invalid",
+            percentile
+        )));
+    }
+    Ok(percentile)
+}
+
+fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> 
{
+    // Extract the desired percentile literal
+    let lit = expr
+        .as_any()
+        .downcast_ref::<Literal>()
+        .ok_or_else(|| {
+            DataFusionError::Internal(
+                "desired percentile argument must be float 
literal".to_string(),
+            )
+        })?
+        .value();
+    let max_size = match lit {
+        ScalarValue::UInt8(Some(q)) => *q as usize,
+        ScalarValue::UInt16(Some(q)) => *q as usize,
+        ScalarValue::UInt32(Some(q)) => *q as usize,
+        ScalarValue::UInt64(Some(q)) => *q as usize,
+        ScalarValue::Int32(Some(q)) if *q > 0 => *q as usize,
+        ScalarValue::Int64(Some(q)) if *q > 0 => *q as usize,
+        ScalarValue::Int16(Some(q)) if *q > 0 => *q as usize,
+        ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize,
+        got => return Err(DataFusionError::NotImplemented(format!(
+            "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt 
> 0 literal (got data type {}).",
+            got.get_datatype()
+        )))
+    };
+    Ok(max_size)
+}
+
 impl AggregateExpr for ApproxPercentileCont {
     fn as_any(&self) -> &dyn Any {
         self
@@ -190,6 +251,18 @@ impl ApproxPercentileAccumulator {
         }
     }
 
+    pub fn new_with_max_size(
+        percentile: f64,
+        return_type: DataType,
+        max_size: usize,
+    ) -> Self {
+        Self {
+            digest: TDigest::new(max_size),
+            percentile,
+            return_type,
+        }
+    }
+
     pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) {
         self.digest = TDigest::merge_digests(digests);
     }
@@ -285,7 +358,6 @@ impl ApproxPercentileAccumulator {
         }
     }
 }
-
 impl Accumulator for ApproxPercentileAccumulator {
     fn state(&self) -> Result<Vec<AggregateState>> {
         Ok(self
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs 
b/datafusion/physical-expr/src/aggregate/build_in.rs
index f47982bec..e6635698c 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -216,12 +216,21 @@ pub fn create_aggregate_expr(
             ));
         }
         (AggregateFunction::ApproxPercentileCont, false) => {
-            Arc::new(expressions::ApproxPercentileCont::new(
-                // Pass in the desired percentile expr
-                coerced_phy_exprs,
-                name,
-                return_type,
-            )?)
+            if coerced_phy_exprs.len() == 2 {
+                Arc::new(expressions::ApproxPercentileCont::new(
+                    // Pass in the desired percentile expr
+                    coerced_phy_exprs,
+                    name,
+                    return_type,
+                )?)
+            } else {
+                Arc::new(expressions::ApproxPercentileCont::new_with_max_size(
+                    // Pass in the desired percentile expr
+                    coerced_phy_exprs,
+                    name,
+                    return_type,
+                )?)
+            }
         }
         (AggregateFunction::ApproxPercentileCont, true) => {
             return Err(DataFusionError::NotImplemented(
diff --git a/docs/source/user-guide/sql/aggregate_functions.md 
b/docs/source/user-guide/sql/aggregate_functions.md
index d3472a7f5..e8299b619 100644
--- a/docs/source/user-guide/sql/aggregate_functions.md
+++ b/docs/source/user-guide/sql/aggregate_functions.md
@@ -53,6 +53,12 @@ Aggregate functions operate on a set of values to compute a 
single result. Pleas
 
 It supports raw data as input and build Tdigest sketches during query time, 
and is approximately equal to `approx_percentile_cont_with_weight(x, 1, p)`.
 
+`approx_percentile_cont(x, p, n) -> x` returns the approximate percentile (TDigest) of input values, where `p` is a float64 between 0 and 1 (inclusive)
+and `n` (default 100) is the maximum number of centroids in the TDigest.
+If there are `n` or fewer unique values in `x`, you can expect an exact result.
+
+A higher value of `n` results in a more accurate approximation at the cost of higher memory usage.
+
 ### approx_percentile_cont_with_weight
 
 `approx_percentile_cont_with_weight(x, w, p) -> x` returns the approximate 
percentile (TDigest) of input values with weight, where `w` is weight column 
expression and `p` is a float64 between 0 and 1 (inclusive).

Reply via email to