This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a9561a0f06 Add regr_slope() aggregate function (#7135)
a9561a0f06 is described below

commit a9561a0f06c25f370dc39df08d057db85c4e0c7a
Author: Yongting You <[email protected]>
AuthorDate: Tue Aug 1 13:33:23 2023 -0700

    Add regr_slope() aggregate function (#7135)
---
 .../tests/sqllogictests/test_files/aggregate.slt   | 158 ++++++++++
 datafusion/expr/src/aggregate_function.rs          |  13 +-
 datafusion/expr/src/type_coercion/aggregates.rs    |  35 +--
 datafusion/physical-expr/src/aggregate/build_in.rs |  11 +
 datafusion/physical-expr/src/aggregate/mod.rs      |   1 +
 .../physical-expr/src/aggregate/regr_slope.rs      | 331 +++++++++++++++++++++
 datafusion/physical-expr/src/expressions/mod.rs    |   1 +
 datafusion/proto/proto/datafusion.proto            |   1 +
 datafusion/proto/src/generated/pbjson.rs           |   3 +
 datafusion/proto/src/generated/prost.rs            |   3 +
 datafusion/proto/src/logical_plan/from_proto.rs    |   1 +
 datafusion/proto/src/logical_plan/to_proto.rs      |   4 +
 docs/source/user-guide/sql/aggregate_functions.md  |  17 ++
 13 files changed, 550 insertions(+), 29 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index 2f6f44c56b..0e3f337071 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -2290,3 +2290,161 @@ true
 false
 true
 NULL
+
+
+
+#
+# regr_slope() tests
+#
+
+# invalid input
+statement error
+select regr_slope();
+
+statement error
+select regr_slope(*);
+
+statement error
+select regr_slope(*) from aggregate_test_100;
+
+statement error
+select regr_slope(1);
+
+statement error
+select regr_slope(1,2,3);
+
+statement error
+select regr_slope(1, 'foo');
+
+statement error
+select regr_slope('foo', 1);
+
+statement error
+select regr_slope('foo', 'bar');
+
+
+
+# regr_slope() NULL result
+query R
+select regr_slope(1,1);
+----
+NULL
+
+query R
+select regr_slope(1, NULL);
+----
+NULL
+
+query R
+select regr_slope(NULL, 1);
+----
+NULL
+
+query R
+select regr_slope(NULL, NULL);
+----
+NULL
+
+query R
+select regr_slope(column2, column1) from (values (1,2), (1,4), (1,6));
+----
+NULL
+
+
+
+# regr_slope() basic tests
+query R
+select regr_slope(column2, column1) from (values (1,2), (2,4), (3,6));
+----
+2
+
+query R
+select regr_slope(c12, c11) from aggregate_test_100;
+----
+0.051534002628
+
+
+
+# regr_slope() ignore NULLs
+query R
+select regr_slope(column2, column1) from (values (1,NULL), (2,4), (3,6));
+----
+2
+
+query R
+select regr_slope(column2, column1) from (values (1,NULL), (NULL,4), (3,6));
+----
+NULL
+
+query R
+select regr_slope(column2, column1) from (values (1,NULL), (NULL,4), (NULL,NULL));
+----
+NULL
+
+query TR rowsort
+select column3, regr_slope(column2, column1)
+from (values (1,2,'a'), (2,4,'a'), (1,3,'b'), (3,9,'b'), (1,10,'c'), (NULL,100,'c'))
+group by column3;
+----
+a 2
+b 3
+c NULL
+
+
+
+# regr_slope() testing merge_batch() from RegrSlopeAccumulator's internal implementation
+statement ok
+set datafusion.execution.batch_size = 1;
+
+query R
+select regr_slope(c12, c11) from aggregate_test_100;
+----
+0.051534002628
+
+statement ok
+set datafusion.execution.batch_size = 2;
+
+query R
+select regr_slope(c12, c11) from aggregate_test_100;
+----
+0.051534002628
+
+statement ok
+set datafusion.execution.batch_size = 3;
+
+query R
+select regr_slope(c12, c11) from aggregate_test_100;
+----
+0.051534002628
+
+statement ok
+set datafusion.execution.batch_size = 8192;
+
+
+
+# regr_slope testing retract_batch() from RegrSlopeAccumulator's internal implementation
+query R
+select regr_slope(column2, column1)
+over (order by column1 rows between 2 preceding and current row)
+from (values (1,2), (2,4), (3,6), (4,12), (5,15), (6, 18));
+----
+NULL
+2
+2
+4
+4.5
+3
+
+query R
+select regr_slope(column2, column1)
+over (order by column1 rows between 2 preceding and current row)
+from (values (1,2), (2,4), (3,6), (3, NULL), (4, NULL), (5,15), (6,18), (7, 21));
+----
+NULL
+2
+2
+2
+NULL
+NULL
+3
+3
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index add1262237..ac0ac3079e 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -61,6 +61,8 @@ pub enum AggregateFunction {
     CovariancePop,
     /// Correlation
     Correlation,
+    /// Slope from linear regression
+    RegrSlope,
     /// Approximate continuous percentile function
     ApproxPercentileCont,
     /// Approximate continuous percentile function with weight
@@ -102,6 +104,7 @@ impl AggregateFunction {
             Covariance => "COVARIANCE",
             CovariancePop => "COVARIANCE_POP",
             Correlation => "CORRELATION",
+            RegrSlope => "REGR_SLOPE",
             ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
             ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
             ApproxMedian => "APPROX_MEDIAN",
@@ -152,6 +155,7 @@ impl FromStr for AggregateFunction {
             "var" => AggregateFunction::Variance,
             "var_pop" => AggregateFunction::VariancePop,
             "var_samp" => AggregateFunction::Variance,
+            "regr_slope" => AggregateFunction::RegrSlope,
             // approximate
             "approx_distinct" => AggregateFunction::ApproxDistinct,
             "approx_median" => AggregateFunction::ApproxMedian,
@@ -228,6 +232,7 @@ impl AggregateFunction {
             }
             AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]),
             AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]),
+            AggregateFunction::RegrSlope => Ok(DataType::Float64),
             AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
             AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
                 "item",
@@ -311,10 +316,10 @@ impl AggregateFunction {
             | AggregateFunction::LastValue => {
                 Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
             }
-            AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
-                Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
-            }
-            AggregateFunction::Correlation => {
+            AggregateFunction::Covariance
+            | AggregateFunction::CovariancePop
+            | AggregateFunction::Correlation
+            | AggregateFunction::RegrSlope => {
                 Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
             }
             AggregateFunction::ApproxPercentileCont => {
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index dec2eb7f12..95ca6ab718 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -148,7 +148,7 @@ pub fn coerce_types(
             }
             Ok(input_types.to_vec())
         }
-        AggregateFunction::Variance => {
+        AggregateFunction::Variance | AggregateFunction::VariancePop => {
             if !is_variance_support_arg_type(&input_types[0]) {
                 return Err(DataFusionError::Plan(format!(
                     "The function {:?} does not support inputs of type {:?}.",
@@ -157,16 +157,7 @@ pub fn coerce_types(
             }
             Ok(input_types.to_vec())
         }
-        AggregateFunction::VariancePop => {
-            if !is_variance_support_arg_type(&input_types[0]) {
-                return Err(DataFusionError::Plan(format!(
-                    "The function {:?} does not support inputs of type {:?}.",
-                    agg_fun, input_types[0]
-                )));
-            }
-            Ok(input_types.to_vec())
-        }
-        AggregateFunction::Covariance => {
+        AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
             if !is_covariance_support_arg_type(&input_types[0]) {
                 return Err(DataFusionError::Plan(format!(
                     "The function {:?} does not support inputs of type {:?}.",
@@ -175,16 +166,7 @@ pub fn coerce_types(
             }
             Ok(input_types.to_vec())
         }
-        AggregateFunction::CovariancePop => {
-            if !is_covariance_support_arg_type(&input_types[0]) {
-                return Err(DataFusionError::Plan(format!(
-                    "The function {:?} does not support inputs of type {:?}.",
-                    agg_fun, input_types[0]
-                )));
-            }
-            Ok(input_types.to_vec())
-        }
-        AggregateFunction::Stddev => {
+        AggregateFunction::Stddev | AggregateFunction::StddevPop => {
             if !is_stddev_support_arg_type(&input_types[0]) {
                 return Err(DataFusionError::Plan(format!(
                     "The function {:?} does not support inputs of type {:?}.",
@@ -193,8 +175,8 @@ pub fn coerce_types(
             }
             Ok(input_types.to_vec())
         }
-        AggregateFunction::StddevPop => {
-            if !is_stddev_support_arg_type(&input_types[0]) {
+        AggregateFunction::Correlation => {
+            if !is_correlation_support_arg_type(&input_types[0]) {
                 return Err(DataFusionError::Plan(format!(
                     "The function {:?} does not support inputs of type {:?}.",
                     agg_fun, input_types[0]
@@ -202,8 +184,11 @@ pub fn coerce_types(
             }
             Ok(input_types.to_vec())
         }
-        AggregateFunction::Correlation => {
-            if !is_correlation_support_arg_type(&input_types[0]) {
+        AggregateFunction::RegrSlope => {
+            let valid_types = [NUMERICS.to_vec(), vec![DataType::Null]].concat();
+            let input_types_valid = // number of input already checked before
+                valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]);
+            if !input_types_valid {
                 return Err(DataFusionError::Plan(format!(
                     "The function {:?} does not support inputs of type {:?}.",
                     agg_fun, input_types[0]
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index 4dc7c824ed..45c98b0187 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -248,6 +248,17 @@ pub fn create_aggregate_expr(
                 "CORR(DISTINCT) aggregations are not available".to_string(),
             ));
         }
+        (AggregateFunction::RegrSlope, false) => Arc::new(expressions::RegrSlope::new(
+            input_phy_exprs[0].clone(),
+            input_phy_exprs[1].clone(),
+            name,
+            rt_type,
+        )),
+        (AggregateFunction::RegrSlope, true) => {
+            return Err(DataFusionError::NotImplemented(
+                "REGR_SLOPE(DISTINCT) aggregations are not available".to_string(),
+            ));
+        }
         (AggregateFunction::ApproxPercentileCont, false) => {
             if input_phy_exprs.len() == 2 {
                 Arc::new(expressions::ApproxPercentileCont::new(
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs
index 5490b87576..0d0abca062 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -49,6 +49,7 @@ pub mod build_in;
 pub(crate) mod groups_accumulator;
 mod hyperloglog;
 pub mod moving_min_max;
+pub(crate) mod regr_slope;
 pub(crate) mod stats;
 pub(crate) mod stddev;
 pub(crate) mod sum;
diff --git a/datafusion/physical-expr/src/aggregate/regr_slope.rs b/datafusion/physical-expr/src/aggregate/regr_slope.rs
new file mode 100644
index 0000000000..fce9627b04
--- /dev/null
+++ b/datafusion/physical-expr/src/aggregate/regr_slope.rs
@@ -0,0 +1,331 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions that can evaluated at runtime during query execution
+
+use std::any::Any;
+use std::sync::Arc;
+
+use crate::{AggregateExpr, PhysicalExpr};
+use arrow::array::Float64Array;
+use arrow::{
+    array::{ArrayRef, UInt64Array},
+    compute::cast,
+    datatypes::DataType,
+    datatypes::Field,
+};
+use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue};
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::Accumulator;
+
+use crate::aggregate::utils::down_cast_any_ref;
+use crate::expressions::format_state_name;
+
+/// regr_slope aggregate expression
+/// Returns the slope of the linear regression line for non-null pairs in aggregate columns
+/// Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal
+/// RSS fitting.
+#[derive(Debug)]
+pub struct RegrSlope {
+    name: String,
+    expr_y: Arc<dyn PhysicalExpr>,
+    expr_x: Arc<dyn PhysicalExpr>,
+}
+
+impl RegrSlope {
+    pub fn new(
+        expr_y: Arc<dyn PhysicalExpr>,
+        expr_x: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        return_type: DataType,
+    ) -> Self {
+        // the result of regr_slope only support FLOAT64 data type.
+        assert!(matches!(return_type, DataType::Float64));
+        Self {
+            name: name.into(),
+            expr_y,
+            expr_x,
+        }
+    }
+}
+
+impl AggregateExpr for RegrSlope {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, DataType::Float64, true))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(RegrSlopeAccumulator::try_new()?))
+    }
+
+    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(RegrSlopeAccumulator::try_new()?))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![
+            Field::new(
+                format_state_name(&self.name, "count"),
+                DataType::UInt64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "mean_x"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "mean_y"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "m2_x"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "algo_const"),
+                DataType::Float64,
+                true,
+            ),
+        ])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr_y.clone(), self.expr_x.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl PartialEq<dyn Any> for RegrSlope {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| {
+                self.name == x.name
+                    && self.expr_y.eq(&x.expr_y)
+                    && self.expr_x.eq(&x.expr_x)
+            })
+            .unwrap_or(false)
+    }
+}
+
+// regr_slope(y, x) is calculated using cov_pop(x, y)/var_pop(x)
+// Reference of online algorithms for calculationg variance:
+// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+#[derive(Debug)]
+pub struct RegrSlopeAccumulator {
+    count: u64,
+    mean_x: f64,
+    mean_y: f64,
+    m2_x: f64,
+    algo_const: f64,
+}
+
+impl RegrSlopeAccumulator {
+    /// Creates a new `RegrSlopeAccumulator`
+    pub fn try_new() -> Result<Self> {
+        Ok(Self {
+            count: 0_u64,
+            mean_x: 0_f64,
+            mean_y: 0_f64,
+            m2_x: 0_f64,
+            algo_const: 0_f64,
+        })
+    }
+}
+
+impl Accumulator for RegrSlopeAccumulator {
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.count),
+            ScalarValue::from(self.mean_x),
+            ScalarValue::from(self.mean_y),
+            ScalarValue::from(self.m2_x),
+            ScalarValue::from(self.algo_const),
+        ])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        // regr_slope(Y, X) calculates k in y = k*x + b
+        let values_y = &cast(&values[0], &DataType::Float64)?;
+        let values_x = &cast(&values[1], &DataType::Float64)?;
+
+        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
+        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
+
+        for i in 0..values_y.len() {
+            // skip either x or y is NULL
+            let value_y = if values_y.is_valid(i) {
+                arr_y.next()
+            } else {
+                None
+            };
+            let value_x = if values_x.is_valid(i) {
+                arr_x.next()
+            } else {
+                None
+            };
+            if value_y.is_none() || value_x.is_none() {
+                continue;
+            }
+
+            // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)]
+            let value_y = unwrap_or_internal_err!(value_y);
+            let value_x = unwrap_or_internal_err!(value_x);
+
+            self.count += 1;
+            let delta_x = value_x - self.mean_x;
+            let delta_y = value_y - self.mean_y;
+            self.mean_x += delta_x / self.count as f64;
+            let delta_x_2 = value_x - self.mean_x;
+            self.m2_x += delta_x * delta_x_2;
+            self.mean_y += delta_y / self.count as f64;
+            self.algo_const += delta_x * (value_y - self.mean_y);
+        }
+
+        Ok(())
+    }
+
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values_y = &cast(&values[0], &DataType::Float64)?;
+        let values_x = &cast(&values[1], &DataType::Float64)?;
+
+        let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten();
+        let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten();
+
+        for i in 0..values_y.len() {
+            // skip either x or y is NULL
+            let value_y = if values_y.is_valid(i) {
+                arr_y.next()
+            } else {
+                None
+            };
+            let value_x = if values_x.is_valid(i) {
+                arr_x.next()
+            } else {
+                None
+            };
+            if value_y.is_none() || value_x.is_none() {
+                continue;
+            }
+
+            // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)]
+            let value_y = unwrap_or_internal_err!(value_y);
+            let value_x = unwrap_or_internal_err!(value_x);
+
+            if self.count > 1 {
+                self.count -= 1;
+                let delta_x = value_x - self.mean_x;
+                let delta_y = value_y - self.mean_y;
+                self.mean_x -= delta_x / self.count as f64;
+                let delta_x_2 = value_x - self.mean_x;
+                self.m2_x -= delta_x * delta_x_2;
+                self.mean_y -= delta_y / self.count as f64;
+                self.algo_const -= delta_x * (value_y - self.mean_y);
+            } else {
+                self.count = 0;
+                self.mean_x = 0.0;
+                self.m2_x = 0.0;
+                self.mean_y = 0.0;
+                self.algo_const = 0.0;
+            }
+        }
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        let count_arr = downcast_value!(states[0], UInt64Array);
+        let mean_x_arr = downcast_value!(states[1], Float64Array);
+        let mean_y_arr = downcast_value!(states[2], Float64Array);
+        let m2_x_arr = downcast_value!(states[3], Float64Array);
+        let algo_const_arr = downcast_value!(states[4], Float64Array);
+
+        for i in 0..count_arr.len() {
+            let count_b = count_arr.value(i);
+            if count_b == 0_u64 {
+                continue;
+            }
+            let (count_a, mean_x_a, mean_y_a, m2_x_a, algo_const_a) = (
+                self.count,
+                self.mean_x,
+                self.mean_y,
+                self.m2_x,
+                self.algo_const,
+            );
+            let (count_b, mean_x_b, mean_y_b, m2_x_b, algo_const_b) = (
+                count_b,
+                mean_x_arr.value(i),
+                mean_y_arr.value(i),
+                m2_x_arr.value(i),
+                algo_const_arr.value(i),
+            );
+
+            // Assuming two different batches of input have calculated the states:
+            // batch A of Y, X -> {count_a, mean_x_a, mean_y_a, m2_x_a, algo_const_a}
+            // batch B of Y, X -> {count_b, mean_x_b, mean_y_b, m2_x_b, algo_const_b}
+            // The merged states from A and B are {count_ab, mean_x_ab, mean_y_ab, m2_x_ab,
+            // algo_const_ab}
+            //
+            // Reference for the algorithm to merge states:
+            // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+            let count_ab = count_a + count_b;
+            let (count_a, count_b) = (count_a as f64, count_b as f64);
+            let d_x = mean_x_b - mean_x_a;
+            let d_y = mean_y_b - mean_y_a;
+            let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
+            let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
+            let m2_x_ab =
+                m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64;
+            let algo_const_ab = algo_const_a
+                + algo_const_b
+                + d_x * d_y * count_a * count_b / count_ab as f64;
+
+            self.count = count_ab;
+            self.mean_x = mean_x_ab;
+            self.mean_y = mean_y_ab;
+            self.m2_x = m2_x_ab;
+            self.algo_const = algo_const_ab;
+        }
+        Ok(())
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        let cov_pop_x_y = self.algo_const / self.count as f64;
+        let var_pop_x = self.m2_x / self.count as f64;
+
+        // Only 0/1 point or slope is infinite
+        if self.count <= 1 || var_pop_x == 0.0 {
+            Ok(ScalarValue::Float64(None))
+        } else {
+            Ok(ScalarValue::Float64(Some(cov_pop_x_y / var_pop_x)))
+        }
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+}
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index c660cfadcc..c56c63db7b 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -60,6 +60,7 @@ pub use crate::aggregate::grouping::Grouping;
 pub use crate::aggregate::median::Median;
 pub use crate::aggregate::min_max::{Max, Min};
 pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator};
+pub use crate::aggregate::regr_slope::RegrSlope;
 pub use crate::aggregate::stats::StatsType;
 pub use crate::aggregate::stddev::{Stddev, StddevPop};
 pub use crate::aggregate::sum::Sum;
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index e9ae76b25d..9694a5beb7 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -613,6 +613,7 @@ enum AggregateFunction {
   // we append "_AGG" to obey name scoping rules.
   FIRST_VALUE_AGG = 24;
   LAST_VALUE_AGG = 25;
+  REGR_SLOPE = 26;
 }
 
 message AggregateExprNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index a5d85cc6cf..40f58b312a 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -465,6 +465,7 @@ impl serde::Serialize for AggregateFunction {
             Self::BoolOr => "BOOL_OR",
             Self::FirstValueAgg => "FIRST_VALUE_AGG",
             Self::LastValueAgg => "LAST_VALUE_AGG",
+            Self::RegrSlope => "REGR_SLOPE",
         };
         serializer.serialize_str(variant)
     }
@@ -502,6 +503,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
             "BOOL_OR",
             "FIRST_VALUE_AGG",
             "LAST_VALUE_AGG",
+            "REGR_SLOPE",
         ];
 
         struct GeneratedVisitor;
@@ -570,6 +572,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
                     "BOOL_OR" => Ok(AggregateFunction::BoolOr),
                     "FIRST_VALUE_AGG" => Ok(AggregateFunction::FirstValueAgg),
                     "LAST_VALUE_AGG" => Ok(AggregateFunction::LastValueAgg),
+                    "REGR_SLOPE" => Ok(AggregateFunction::RegrSlope),
                     _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
                 }
             }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index c6f3a23ed6..7e4a5f8afd 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2578,6 +2578,7 @@ pub enum AggregateFunction {
     /// we append "_AGG" to obey name scoping rules.
     FirstValueAgg = 24,
     LastValueAgg = 25,
+    RegrSlope = 26,
 }
 impl AggregateFunction {
     /// String value of the enum field names used in the ProtoBuf definition.
@@ -2614,6 +2615,7 @@ impl AggregateFunction {
             AggregateFunction::BoolOr => "BOOL_OR",
             AggregateFunction::FirstValueAgg => "FIRST_VALUE_AGG",
             AggregateFunction::LastValueAgg => "LAST_VALUE_AGG",
+            AggregateFunction::RegrSlope => "REGR_SLOPE",
         }
     }
     /// Creates an enum from field names used in the ProtoBuf definition.
@@ -2647,6 +2649,7 @@ impl AggregateFunction {
             "BOOL_OR" => Some(Self::BoolOr),
             "FIRST_VALUE_AGG" => Some(Self::FirstValueAgg),
             "LAST_VALUE_AGG" => Some(Self::LastValueAgg),
+            "REGR_SLOPE" => Some(Self::RegrSlope),
             _ => None,
         }
     }
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 1464f32bb3..4caff5fba0 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -549,6 +549,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
             protobuf::AggregateFunction::Stddev => Self::Stddev,
             protobuf::AggregateFunction::StddevPop => Self::StddevPop,
             protobuf::AggregateFunction::Correlation => Self::Correlation,
+            protobuf::AggregateFunction::RegrSlope => Self::RegrSlope,
             protobuf::AggregateFunction::ApproxPercentileCont => {
                 Self::ApproxPercentileCont
             }
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index df5701a282..3f4fdfeb74 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -384,6 +384,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
             AggregateFunction::Stddev => Self::Stddev,
             AggregateFunction::StddevPop => Self::StddevPop,
             AggregateFunction::Correlation => Self::Correlation,
+            AggregateFunction::RegrSlope => Self::RegrSlope,
             AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont,
             AggregateFunction::ApproxPercentileContWithWeight => {
                 Self::ApproxPercentileContWithWeight
@@ -675,6 +676,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                     AggregateFunction::Correlation => {
                         protobuf::AggregateFunction::Correlation
                     }
+                    AggregateFunction::RegrSlope => {
+                        protobuf::AggregateFunction::RegrSlope
+                    }
                     AggregateFunction::ApproxMedian => {
                         protobuf::AggregateFunction::ApproxMedian
                     }
diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md
index 132ba47e24..71168b0622 100644
--- a/docs/source/user-guide/sql/aggregate_functions.md
+++ b/docs/source/user-guide/sql/aggregate_functions.md
@@ -245,6 +245,7 @@ last_value(expression [ORDER BY expression])
 - [var](#var)
 - [var_pop](#var_pop)
 - [var_samp](#var_samp)
+- [regr_slope](#regr_slope)
 
 ### `corr`
 
@@ -384,6 +385,22 @@ var_samp(expression)
 - **expression**: Expression to operate on.
   Can be a constant, column, or function, and any combination of arithmetic 
operators.
 
+### `regr_slope`
+
+Returns the slope of the linear regression line for non-null pairs in aggregate columns.
+Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X + b) using minimal RSS fitting.
+
+```
+regr_slope(expression1, expression2)
+```
+
+#### Arguments
+
+- **expression1**: Expression to operate on.
+  Can be a constant, column, or function, and any combination of arithmetic operators.
+- **expression2**: Expression to operate on.
+  Can be a constant, column, or function, and any combination of arithmetic operators.
+
 ## Approximate
 
 - [approx_distinct](#approx_distinct)

Reply via email to