This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 5ec45dd add regr_* functions (#499)
5ec45dd is described below
commit 5ec45ddd5f3b44a3b39d591e3aa3c6eb8880c5ee
Author: zhenxing jiang <[email protected]>
AuthorDate: Sun Oct 15 11:05:57 2023 -0500
add regr_* functions (#499)
Co-authored-by: Andy Grove <[email protected]>
---
datafusion/tests/test_functions.py | 22 ++++++++++++++++++++++
src/functions.rs | 18 ++++++++++++++++++
2 files changed, 40 insertions(+)
diff --git a/datafusion/tests/test_functions.py b/datafusion/tests/test_functions.py
index e504cc4..be2a2f1 100644
--- a/datafusion/tests/test_functions.py
+++ b/datafusion/tests/test_functions.py
@@ -479,6 +479,28 @@ def test_case(df):
assert result.column(2) == pa.array(["Hola", "Mundo", None])
+def test_regr_funcs(df):
+ # test case base on
+ # https://github.com/apache/arrow-datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2330
+ ctx = SessionContext()
+ result = ctx.sql(
+ "select regr_slope(1,1), regr_intercept(1,1), "
+ "regr_count(1,1), regr_r2(1,1), regr_avgx(1,1), "
+ "regr_avgy(1,1), regr_sxx(1,1), regr_syy(1,1), "
+ "regr_sxy(1,1);"
+ ).collect()
+
+ assert result[0].column(0) == pa.array([None], type=pa.float64())
+ assert result[0].column(1) == pa.array([None], type=pa.float64())
+ assert result[0].column(2) == pa.array([1], type=pa.float64())
+ assert result[0].column(3) == pa.array([None], type=pa.float64())
+ assert result[0].column(4) == pa.array([1], type=pa.float64())
+ assert result[0].column(5) == pa.array([1], type=pa.float64())
+ assert result[0].column(6) == pa.array([0], type=pa.float64())
+ assert result[0].column(7) == pa.array([0], type=pa.float64())
+ assert result[0].column(8) == pa.array([0], type=pa.float64())
+
+
def test_first_last_value(df):
df = df.aggregate(
[],
diff --git a/src/functions.rs b/src/functions.rs
index 2f2f34e..e509aff 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -362,6 +362,15 @@ aggregate_function!(stddev_samp, Stddev);
aggregate_function!(var, Variance);
aggregate_function!(var_pop, VariancePop);
aggregate_function!(var_samp, Variance);
+aggregate_function!(regr_avgx, RegrAvgx);
+aggregate_function!(regr_avgy, RegrAvgy);
+aggregate_function!(regr_count, RegrCount);
+aggregate_function!(regr_intercept, RegrIntercept);
+aggregate_function!(regr_r2, RegrR2);
+aggregate_function!(regr_slope, RegrSlope);
+aggregate_function!(regr_sxx, RegrSXX);
+aggregate_function!(regr_sxy, RegrSXY);
+aggregate_function!(regr_syy, RegrSYY);
aggregate_function!(first_value, FirstValue);
aggregate_function!(last_value, LastValue);
aggregate_function!(bit_and, BitAnd);
@@ -496,6 +505,15 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(var_pop))?;
m.add_wrapped(wrap_pyfunction!(var_samp))?;
m.add_wrapped(wrap_pyfunction!(window))?;
+ m.add_wrapped(wrap_pyfunction!(regr_avgx))?;
+ m.add_wrapped(wrap_pyfunction!(regr_avgy))?;
+ m.add_wrapped(wrap_pyfunction!(regr_count))?;
+ m.add_wrapped(wrap_pyfunction!(regr_intercept))?;
+ m.add_wrapped(wrap_pyfunction!(regr_r2))?;
+ m.add_wrapped(wrap_pyfunction!(regr_slope))?;
+ m.add_wrapped(wrap_pyfunction!(regr_sxx))?;
+ m.add_wrapped(wrap_pyfunction!(regr_sxy))?;
+ m.add_wrapped(wrap_pyfunction!(regr_syy))?;
m.add_wrapped(wrap_pyfunction!(first_value))?;
m.add_wrapped(wrap_pyfunction!(last_value))?;
m.add_wrapped(wrap_pyfunction!(bit_and))?;