This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 5ec45dd  add regr_* functions (#499)
5ec45dd is described below

commit 5ec45ddd5f3b44a3b39d591e3aa3c6eb8880c5ee
Author: zhenxing jiang <[email protected]>
AuthorDate: Sun Oct 15 11:05:57 2023 -0500

    add regr_* functions (#499)
    
    Co-authored-by: Andy Grove <[email protected]>
---
 datafusion/tests/test_functions.py | 22 ++++++++++++++++++++++
 src/functions.rs                   | 18 ++++++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/datafusion/tests/test_functions.py b/datafusion/tests/test_functions.py
index e504cc4..be2a2f1 100644
--- a/datafusion/tests/test_functions.py
+++ b/datafusion/tests/test_functions.py
@@ -479,6 +479,28 @@ def test_case(df):
     assert result.column(2) == pa.array(["Hola", "Mundo", None])
 
 
+def test_regr_funcs(df):
+    # test case based on
+    # https://github.com/apache/arrow-datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2330
+    ctx = SessionContext()
+    result = ctx.sql(
+        "select regr_slope(1,1), regr_intercept(1,1), "
+        "regr_count(1,1), regr_r2(1,1), regr_avgx(1,1), "
+        "regr_avgy(1,1), regr_sxx(1,1), regr_syy(1,1), "
+        "regr_sxy(1,1);"
+    ).collect()
+
+    assert result[0].column(0) == pa.array([None], type=pa.float64())
+    assert result[0].column(1) == pa.array([None], type=pa.float64())
+    assert result[0].column(2) == pa.array([1], type=pa.float64())
+    assert result[0].column(3) == pa.array([None], type=pa.float64())
+    assert result[0].column(4) == pa.array([1], type=pa.float64())
+    assert result[0].column(5) == pa.array([1], type=pa.float64())
+    assert result[0].column(6) == pa.array([0], type=pa.float64())
+    assert result[0].column(7) == pa.array([0], type=pa.float64())
+    assert result[0].column(8) == pa.array([0], type=pa.float64())
+
+
 def test_first_last_value(df):
     df = df.aggregate(
         [],
diff --git a/src/functions.rs b/src/functions.rs
index 2f2f34e..e509aff 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -362,6 +362,15 @@ aggregate_function!(stddev_samp, Stddev);
 aggregate_function!(var, Variance);
 aggregate_function!(var_pop, VariancePop);
 aggregate_function!(var_samp, Variance);
+aggregate_function!(regr_avgx, RegrAvgx);
+aggregate_function!(regr_avgy, RegrAvgy);
+aggregate_function!(regr_count, RegrCount);
+aggregate_function!(regr_intercept, RegrIntercept);
+aggregate_function!(regr_r2, RegrR2);
+aggregate_function!(regr_slope, RegrSlope);
+aggregate_function!(regr_sxx, RegrSXX);
+aggregate_function!(regr_sxy, RegrSXY);
+aggregate_function!(regr_syy, RegrSYY);
 aggregate_function!(first_value, FirstValue);
 aggregate_function!(last_value, LastValue);
 aggregate_function!(bit_and, BitAnd);
@@ -496,6 +505,15 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(var_pop))?;
     m.add_wrapped(wrap_pyfunction!(var_samp))?;
     m.add_wrapped(wrap_pyfunction!(window))?;
+    m.add_wrapped(wrap_pyfunction!(regr_avgx))?;
+    m.add_wrapped(wrap_pyfunction!(regr_avgy))?;
+    m.add_wrapped(wrap_pyfunction!(regr_count))?;
+    m.add_wrapped(wrap_pyfunction!(regr_intercept))?;
+    m.add_wrapped(wrap_pyfunction!(regr_r2))?;
+    m.add_wrapped(wrap_pyfunction!(regr_slope))?;
+    m.add_wrapped(wrap_pyfunction!(regr_sxx))?;
+    m.add_wrapped(wrap_pyfunction!(regr_sxy))?;
+    m.add_wrapped(wrap_pyfunction!(regr_syy))?;
     m.add_wrapped(wrap_pyfunction!(first_value))?;
     m.add_wrapped(wrap_pyfunction!(last_value))?;
     m.add_wrapped(wrap_pyfunction!(bit_and))?;

Reply via email to