This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 2057eb7e203 [SPARK-43933][SQL][PYTHON][CONNECT] Add linear regression aggregate functions to Scala and Python 2057eb7e203 is described below commit 2057eb7e203c9fde3f4fa13d5f04225cf6e49a87 Author: Jiaan Geng <belie...@163.com> AuthorDate: Wed Jun 7 22:56:13 2023 +0800 [SPARK-43933][SQL][PYTHON][CONNECT] Add linear regression aggregate functions to Scala and Python ### What changes were proposed in this pull request? Based on HyukjinKwon's suggestion, this PR adds linear regression aggregate functions to the Scala and Python APIs. The functions are listed below. - `regr_avgx` - `regr_avgy` - `regr_count` - `regr_intercept` - `regr_r2` - `regr_slope` - `regr_sxx` - `regr_sxy` - `regr_syy` ### Why are the changes needed? To expose the linear regression aggregate functions in the Scala and Python APIs. ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? New test cases. Closes #41474 from beliefer/SPARK-43933. 
Authored-by: Jiaan Geng <belie...@163.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../scala/org/apache/spark/sql/functions.scala | 82 ++++++ .../apache/spark/sql/PlanGenerationTestSuite.scala | 36 +++ .../explain-results/function_regr_avgx.explain | 2 + .../explain-results/function_regr_avgy.explain | 2 + .../explain-results/function_regr_count.explain | 2 + .../function_regr_intercept.explain | 2 + .../explain-results/function_regr_r2.explain | 2 + .../explain-results/function_regr_slope.explain | 2 + .../explain-results/function_regr_sxx.explain | 2 + .../explain-results/function_regr_sxy.explain | 2 + .../explain-results/function_regr_syy.explain | 2 + .../query-tests/queries/function_regr_avgx.json | 29 +++ .../queries/function_regr_avgx.proto.bin | Bin 0 -> 185 bytes .../query-tests/queries/function_regr_avgy.json | 29 +++ .../queries/function_regr_avgy.proto.bin | Bin 0 -> 185 bytes .../query-tests/queries/function_regr_count.json | 29 +++ .../queries/function_regr_count.proto.bin | Bin 0 -> 186 bytes .../queries/function_regr_intercept.json | 29 +++ .../queries/function_regr_intercept.proto.bin | Bin 0 -> 190 bytes .../query-tests/queries/function_regr_r2.json | 29 +++ .../query-tests/queries/function_regr_r2.proto.bin | Bin 0 -> 183 bytes .../query-tests/queries/function_regr_slope.json | 29 +++ .../queries/function_regr_slope.proto.bin | Bin 0 -> 186 bytes .../query-tests/queries/function_regr_sxx.json | 29 +++ .../queries/function_regr_sxx.proto.bin | Bin 0 -> 184 bytes .../query-tests/queries/function_regr_sxy.json | 29 +++ .../queries/function_regr_sxy.proto.bin | Bin 0 -> 184 bytes .../query-tests/queries/function_regr_syy.json | 29 +++ .../queries/function_regr_syy.proto.bin | Bin 0 -> 184 bytes .../source/reference/pyspark.sql/functions.rst | 9 + python/pyspark/sql/connect/functions.py | 63 +++++ python/pyspark/sql/functions.py | 280 +++++++++++++++++++++ .../scala/org/apache/spark/sql/functions.scala | 83 ++++++ 
.../apache/spark/sql/DataFrameAggregateSuite.scala | 13 + 34 files changed, 845 insertions(+) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 60634388fa1..9179b88a26d 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -992,6 +992,88 @@ object functions { */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) + /** + * Aggregate function: returns the average of the independent variable for non-null pairs in a + * group, where `y` is the dependent variable and `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_avgx(y: Column, x: Column): Column = Column.fn("regr_avgx", y, x) + + /** + * Aggregate function: returns the average of the dependent variable for non-null pairs in a + * group, where `y` is the dependent variable and `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_avgy(y: Column, x: Column): Column = Column.fn("regr_avgy", y, x) + + /** + * Aggregate function: returns the number of non-null number pairs in a group, where `y` is the + * dependent variable and `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_count(y: Column, x: Column): Column = Column.fn("regr_count", y, x) + + /** + * Aggregate function: returns the intercept of the univariate linear regression line for + * non-null pairs in a group, where `y` is the dependent variable and `x` is the independent + * variable. 
+ * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_intercept(y: Column, x: Column): Column = Column.fn("regr_intercept", y, x) + + /** + * Aggregate function: returns the coefficient of determination for non-null pairs in a group, + * where `y` is the dependent variable and `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_r2(y: Column, x: Column): Column = Column.fn("regr_r2", y, x) + + /** + * Aggregate function: returns the slope of the linear regression line for non-null pairs in a + * group, where `y` is the dependent variable and `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_slope(y: Column, x: Column): Column = Column.fn("regr_slope", y, x) + + /** + * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(x) for non-null pairs in a group, + * where `y` is the dependent variable and `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_sxx(y: Column, x: Column): Column = Column.fn("regr_sxx", y, x) + + /** + * Aggregate function: returns REGR_COUNT(y, x) * COVAR_POP(y, x) for non-null pairs in a group, + * where `y` is the dependent variable and `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_sxy(y: Column, x: Column): Column = Column.fn("regr_sxy", y, x) + + /** + * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(y) for non-null pairs in a group, + * where `y` is the dependent variable and `x` is the independent variable. 
+ * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_syy(y: Column, x: Column): Column = Column.fn("regr_syy", y, x) + ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index b104807d8a5..e652594ab1f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -1030,6 +1030,42 @@ class PlanGenerationTestSuite fn.var_pop("a") } + functionTest("regr_avgx") { + fn.regr_avgx(fn.col("a"), fn.col("b")) + } + + functionTest("regr_avgy") { + fn.regr_avgy(fn.col("a"), fn.col("b")) + } + + functionTest("regr_count") { + fn.regr_count(fn.col("a"), fn.col("b")) + } + + functionTest("regr_intercept") { + fn.regr_intercept(fn.col("a"), fn.col("b")) + } + + functionTest("regr_r2") { + fn.regr_r2(fn.col("a"), fn.col("b")) + } + + functionTest("regr_slope") { + fn.regr_slope(fn.col("a"), fn.col("b")) + } + + functionTest("regr_sxx") { + fn.regr_sxx(fn.col("a"), fn.col("b")) + } + + functionTest("regr_sxy") { + fn.regr_sxy(fn.col("a"), fn.col("b")) + } + + functionTest("regr_syy") { + fn.regr_syy(fn.col("a"), fn.col("b")) + } + functionTest("array") { fn.array("a", "a") } diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_avgx.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_avgx.explain new file mode 100644 index 00000000000..3f6fc469208 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_avgx.explain @@ -0,0 +1,2 @@ 
+Aggregate [avg(if ((isnotnull(a#0) AND isnotnull(b#0))) b#0 else null) AS regr_avgx(a, b)#0] ++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_avgy.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_avgy.explain new file mode 100644 index 00000000000..5c2f8b60849 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_avgy.explain @@ -0,0 +1,2 @@ +Aggregate [avg(if ((isnotnull(a#0) AND isnotnull(b#0))) a#0 else null) AS regr_avgy(a, b)#0] ++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_count.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_count.explain new file mode 100644 index 00000000000..b0b57d1dff3 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_count.explain @@ -0,0 +1,2 @@ +Aggregate [count(a#0, b#0) AS regr_count(a, b)#0L] ++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_intercept.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_intercept.explain new file mode 100644 index 00000000000..d3bc25218c6 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_intercept.explain @@ -0,0 +1,2 @@ +Aggregate [regr_intercept(cast(a#0 as double), b#0) AS regr_intercept(a, b)#0] ++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_r2.explain 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_r2.explain new file mode 100644 index 00000000000..d997d7d70f3 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_r2.explain @@ -0,0 +1,2 @@ +Aggregate [regr_r2(cast(a#0 as double), b#0) AS regr_r2(a, b)#0] ++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_slope.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_slope.explain new file mode 100644 index 00000000000..f3b5f5074db --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_slope.explain @@ -0,0 +1,2 @@ +Aggregate [regr_slope(cast(a#0 as double), b#0) AS regr_slope(a, b)#0] ++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_sxx.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_sxx.explain new file mode 100644 index 00000000000..fee587e1360 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_sxx.explain @@ -0,0 +1,2 @@ +Aggregate [regrreplacement(if ((isnull(cast(a#0 as double)) OR isnull(b#0))) null else b#0) AS regr_sxx(a, b)#0] ++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_sxy.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_sxy.explain new file mode 100644 index 00000000000..43bebd8bbf2 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_sxy.explain @@ -0,0 +1,2 @@ +Aggregate [regr_sxy(cast(a#0 as double), b#0) AS regr_sxy(a, b)#0] ++- 
LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_syy.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_syy.explain new file mode 100644 index 00000000000..fd06d1e2e30 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regr_syy.explain @@ -0,0 +1,2 @@ +Aggregate [regrreplacement(if ((isnull(cast(a#0 as double)) OR isnull(b#0))) null else cast(a#0 as double)) AS regr_syy(a, b)#0] ++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_avgx.json b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_avgx.json new file mode 100644 index 00000000000..4fdc9b035d7 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_avgx.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "regr_avgx", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_avgx.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_avgx.proto.bin new file mode 100644 index 00000000000..5771d141728 Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_avgx.proto.bin 
differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_avgy.json b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_avgy.json new file mode 100644 index 00000000000..af225fdf5a8 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_avgy.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "regr_avgy", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_avgy.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_avgy.proto.bin new file mode 100644 index 00000000000..0a6dcf0106a Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_avgy.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_count.json b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_count.json new file mode 100644 index 00000000000..510fc78140a --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_count.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": 
"struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "regr_count", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_count.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_count.proto.bin new file mode 100644 index 00000000000..b1eff9f4d03 Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_count.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_intercept.json b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_intercept.json new file mode 100644 index 00000000000..a8596615a2d --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_intercept.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "regr_intercept", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_intercept.proto.bin 
b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_intercept.proto.bin new file mode 100644 index 00000000000..b9a1c0eff89 Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_intercept.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_r2.json b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_r2.json new file mode 100644 index 00000000000..9f88c6ad412 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_r2.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "regr_r2", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_r2.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_r2.proto.bin new file mode 100644 index 00000000000..0011348d388 Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_r2.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_slope.json b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_slope.json new file mode 100644 index 00000000000..9503b2c6fef --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_slope.json @@ -0,0 +1,29 @@ +{ + 
"common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "regr_slope", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_slope.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_slope.proto.bin new file mode 100644 index 00000000000..69c918a7861 Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_slope.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_sxx.json b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_sxx.json new file mode 100644 index 00000000000..fb243c9989e --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_sxx.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "regr_sxx", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }] + } +} \ No newline at end of file diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_sxx.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_sxx.proto.bin new file mode 100644 index 00000000000..df31a2e6851 Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_sxx.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_sxy.json b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_sxy.json new file mode 100644 index 00000000000..459deaa391e --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_sxy.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "regr_sxy", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_sxy.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_sxy.proto.bin new file mode 100644 index 00000000000..db51c0bc32a Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_sxy.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_syy.json b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_syy.json new file mode 100644 index 00000000000..877fbc3aa7c --- /dev/null +++ 
b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_syy.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "regr_syy", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_regr_syy.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_syy.proto.bin new file mode 100644 index 00000000000..6452b277a6e Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_regr_syy.proto.bin differ diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index e1317bb9228..32550763268 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -258,6 +258,15 @@ Aggregate Functions mode percentile_approx product + regr_avgx + regr_avgy + regr_count + regr_intercept + regr_r2 + regr_slope + regr_sxx + regr_sxy + regr_syy skewness stddev stddev_pop diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index c986c34b8d6..21f2fc2576d 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -1051,6 +1051,69 @@ def var_pop(col: "ColumnOrName") -> Column: var_pop.__doc__ = pysparkfuncs.var_pop.__doc__ +def regr_avgx(y: "ColumnOrName", x: "ColumnOrName") -> Column: + return 
_invoke_function_over_columns("regr_avgx", y, x) + + +regr_avgx.__doc__ = pysparkfuncs.regr_avgx.__doc__ + + +def regr_avgy(y: "ColumnOrName", x: "ColumnOrName") -> Column: + return _invoke_function_over_columns("regr_avgy", y, x) + + +regr_avgy.__doc__ = pysparkfuncs.regr_avgy.__doc__ + + +def regr_count(y: "ColumnOrName", x: "ColumnOrName") -> Column: + return _invoke_function_over_columns("regr_count", y, x) + + +regr_count.__doc__ = pysparkfuncs.regr_count.__doc__ + + +def regr_intercept(y: "ColumnOrName", x: "ColumnOrName") -> Column: + return _invoke_function_over_columns("regr_intercept", y, x) + + +regr_intercept.__doc__ = pysparkfuncs.regr_intercept.__doc__ + + +def regr_r2(y: "ColumnOrName", x: "ColumnOrName") -> Column: + return _invoke_function_over_columns("regr_r2", y, x) + + +regr_r2.__doc__ = pysparkfuncs.regr_r2.__doc__ + + +def regr_slope(y: "ColumnOrName", x: "ColumnOrName") -> Column: + return _invoke_function_over_columns("regr_slope", y, x) + + +regr_slope.__doc__ = pysparkfuncs.regr_slope.__doc__ + + +def regr_sxx(y: "ColumnOrName", x: "ColumnOrName") -> Column: + return _invoke_function_over_columns("regr_sxx", y, x) + + +regr_sxx.__doc__ = pysparkfuncs.regr_sxx.__doc__ + + +def regr_sxy(y: "ColumnOrName", x: "ColumnOrName") -> Column: + return _invoke_function_over_columns("regr_sxy", y, x) + + +regr_sxy.__doc__ = pysparkfuncs.regr_sxy.__doc__ + + +def regr_syy(y: "ColumnOrName", x: "ColumnOrName") -> Column: + return _invoke_function_over_columns("regr_syy", y, x) + + +regr_syy.__doc__ = pysparkfuncs.regr_syy.__doc__ + + def var_samp(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("var_samp", col) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9a0c96be70e..93a4056d8b2 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2073,6 +2073,286 @@ def var_pop(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("var_pop", col) 
+@try_remote_functions +def regr_avgx(y: "ColumnOrName", x: "ColumnOrName") -> Column: + """ + Aggregate function: returns the average of the independent variable for non-null pairs + in a group, where `y` is the dependent variable and `x` is the independent variable. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + y : :class:`~pyspark.sql.Column` or str + the dependent variable. + x : :class:`~pyspark.sql.Column` or str + the independent variable. + + Returns + ------- + :class:`~pyspark.sql.Column` + the average of the independent variable for non-null pairs in a group. + + Examples + -------- + >>> x = (col("id") % 3).alias("x") + >>> y = (randn(42) + x * 10).alias("y") + >>> df = spark.range(0, 1000, 1, 1).select(x, y) + >>> df.select(regr_avgx("y", "x")).first() + Row(regr_avgx(y, x)=0.999) + """ + return _invoke_function_over_columns("regr_avgx", y, x) + + +@try_remote_functions +def regr_avgy(y: "ColumnOrName", x: "ColumnOrName") -> Column: + """ + Aggregate function: returns the average of the dependent variable for non-null pairs + in a group, where `y` is the dependent variable and `x` is the independent variable. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + y : :class:`~pyspark.sql.Column` or str + the dependent variable. + x : :class:`~pyspark.sql.Column` or str + the independent variable. + + Returns + ------- + :class:`~pyspark.sql.Column` + the average of the dependent variable for non-null pairs in a group. 
+ + Examples + -------- + >>> x = (col("id") % 3).alias("x") + >>> y = (randn(42) + x * 10).alias("y") + >>> df = spark.range(0, 1000, 1, 1).select(x, y) + >>> df.select(regr_avgy("y", "x")).first() + Row(regr_avgy(y, x)=9.980732994136464) + """ + return _invoke_function_over_columns("regr_avgy", y, x) + + +@try_remote_functions +def regr_count(y: "ColumnOrName", x: "ColumnOrName") -> Column: + """ + Aggregate function: returns the number of non-null number pairs + in a group, where `y` is the dependent variable and `x` is the independent variable. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + y : :class:`~pyspark.sql.Column` or str + the dependent variable. + x : :class:`~pyspark.sql.Column` or str + the independent variable. + + Returns + ------- + :class:`~pyspark.sql.Column` + the number of non-null number pairs in a group. + + Examples + -------- + >>> x = (col("id") % 3).alias("x") + >>> y = (randn(42) + x * 10).alias("y") + >>> df = spark.range(0, 1000, 1, 1).select(x, y) + >>> df.select(regr_count("y", "x")).first() + Row(regr_count(y, x)=1000) + """ + return _invoke_function_over_columns("regr_count", y, x) + + +@try_remote_functions +def regr_intercept(y: "ColumnOrName", x: "ColumnOrName") -> Column: + """ + Aggregate function: returns the intercept of the univariate linear regression line + for non-null pairs in a group, where `y` is the dependent variable and + `x` is the independent variable. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + y : :class:`~pyspark.sql.Column` or str + the dependent variable. + x : :class:`~pyspark.sql.Column` or str + the independent variable. + + Returns + ------- + :class:`~pyspark.sql.Column` + the intercept of the univariate linear regression line for non-null pairs in a group. 
+ + Examples + -------- + >>> x = (col("id") % 3).alias("x") + >>> y = (randn(42) + x * 10).alias("y") + >>> df = spark.range(0, 1000, 1, 1).select(x, y) + >>> df.select(regr_intercept("y", "x")).first() + Row(regr_intercept(y, x)=-0.04961745990969568) + """ + return _invoke_function_over_columns("regr_intercept", y, x) + + +@try_remote_functions +def regr_r2(y: "ColumnOrName", x: "ColumnOrName") -> Column: + """ + Aggregate function: returns the coefficient of determination for non-null pairs + in a group, where `y` is the dependent variable and `x` is the independent variable. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + y : :class:`~pyspark.sql.Column` or str + the dependent variable. + x : :class:`~pyspark.sql.Column` or str + the independent variable. + + Returns + ------- + :class:`~pyspark.sql.Column` + the coefficient of determination for non-null pairs in a group. + + Examples + -------- + >>> x = (col("id") % 3).alias("x") + >>> y = (randn(42) + x * 10).alias("y") + >>> df = spark.range(0, 1000, 1, 1).select(x, y) + >>> df.select(regr_r2("y", "x")).first() + Row(regr_r2(y, x)=0.9851908293645436) + """ + return _invoke_function_over_columns("regr_r2", y, x) + + +@try_remote_functions +def regr_slope(y: "ColumnOrName", x: "ColumnOrName") -> Column: + """ + Aggregate function: returns the slope of the linear regression line for non-null pairs + in a group, where `y` is the dependent variable and `x` is the independent variable. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + y : :class:`~pyspark.sql.Column` or str + the dependent variable. + x : :class:`~pyspark.sql.Column` or str + the independent variable. + + Returns + ------- + :class:`~pyspark.sql.Column` + the slope of the linear regression line for non-null pairs in a group. 
+ + Examples + -------- + >>> x = (col("id") % 3).alias("x") + >>> y = (randn(42) + x * 10).alias("y") + >>> df = spark.range(0, 1000, 1, 1).select(x, y) + >>> df.select(regr_slope("y", "x")).first() + Row(regr_slope(y, x)=10.040390844891048) + """ + return _invoke_function_over_columns("regr_slope", y, x) + + +@try_remote_functions +def regr_sxx(y: "ColumnOrName", x: "ColumnOrName") -> Column: + """ + Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(x) for non-null pairs + in a group, where `y` is the dependent variable and `x` is the independent variable. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + y : :class:`~pyspark.sql.Column` or str + the dependent variable. + x : :class:`~pyspark.sql.Column` or str + the independent variable. + + Returns + ------- + :class:`~pyspark.sql.Column` + REGR_COUNT(y, x) * VAR_POP(x) for non-null pairs in a group. + + Examples + -------- + >>> x = (col("id") % 3).alias("x") + >>> y = (randn(42) + x * 10).alias("y") + >>> df = spark.range(0, 1000, 1, 1).select(x, y) + >>> df.select(regr_sxx("y", "x")).first() + Row(regr_sxx(y, x)=666.9989999999996) + """ + return _invoke_function_over_columns("regr_sxx", y, x) + + +@try_remote_functions +def regr_sxy(y: "ColumnOrName", x: "ColumnOrName") -> Column: + """ + Aggregate function: returns REGR_COUNT(y, x) * COVAR_POP(y, x) for non-null pairs + in a group, where `y` is the dependent variable and `x` is the independent variable. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + y : :class:`~pyspark.sql.Column` or str + the dependent variable. + x : :class:`~pyspark.sql.Column` or str + the independent variable. + + Returns + ------- + :class:`~pyspark.sql.Column` + REGR_COUNT(y, x) * COVAR_POP(y, x) for non-null pairs in a group. 
+ + Examples + -------- + >>> x = (col("id") % 3).alias("x") + >>> y = (randn(42) + x * 10).alias("y") + >>> df = spark.range(0, 1000, 1, 1).select(x, y) + >>> df.select(regr_sxy("y", "x")).first() + Row(regr_sxy(y, x)=6696.93065315148) + """ + return _invoke_function_over_columns("regr_sxy", y, x) + + +@try_remote_functions +def regr_syy(y: "ColumnOrName", x: "ColumnOrName") -> Column: + """ + Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(y) for non-null pairs + in a group, where `y` is the dependent variable and `x` is the independent variable. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + y : :class:`~pyspark.sql.Column` or str + the dependent variable. + x : :class:`~pyspark.sql.Column` or str + the independent variable. + + Returns + ------- + :class:`~pyspark.sql.Column` + REGR_COUNT(y, x) * VAR_POP(y) for non-null pairs in a group. + + Examples + -------- + >>> x = (col("id") % 3).alias("x") + >>> y = (randn(42) + x * 10).alias("y") + >>> df = spark.range(0, 1000, 1, 1).select(x, y) + >>> df.select(regr_syy("y", "x")).first() + Row(regr_syy(y, x)=68250.53503811295) + """ + return _invoke_function_over_columns("regr_syy", y, x) + + @try_remote_functions def skewness(col: "ColumnOrName") -> Column: """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f3c307bf0f7..a36fc3b066d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1043,6 +1043,89 @@ object functions { */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) + /** + * Aggregate function: returns the average of the independent variable for non-null pairs + * in a group, where `y` is the dependent variable and `x` is the independent variable. 
+ * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_avgx(y: Column, x: Column): Column = withAggregateFunction { RegrAvgX(y.expr, x.expr) } + + /** + * Aggregate function: returns the average of the dependent variable for non-null pairs + * in a group, where `y` is the dependent variable and `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_avgy(y: Column, x: Column): Column = withAggregateFunction { RegrAvgY(y.expr, x.expr) } + + /** + * Aggregate function: returns the number of non-null number pairs + * in a group, where `y` is the dependent variable and `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_count(y: Column, x: Column): Column = withAggregateFunction { RegrCount(y.expr, x.expr) } + + /** + * Aggregate function: returns the intercept of the univariate linear regression line + * for non-null pairs in a group, where `y` is the dependent variable and + * `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_intercept(y: Column, x: Column): Column = + withAggregateFunction { RegrIntercept(y.expr, x.expr) } + + /** + * Aggregate function: returns the coefficient of determination for non-null pairs + * in a group, where `y` is the dependent variable and `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_r2(y: Column, x: Column): Column = withAggregateFunction { RegrR2(y.expr, x.expr) } + + /** + * Aggregate function: returns the slope of the linear regression line for non-null pairs + * in a group, where `y` is the dependent variable and `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_slope(y: Column, x: Column): Column = + withAggregateFunction { RegrSlope(y.expr, x.expr) } + + /** + * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(x) for non-null pairs + * in a group, where `y` is the dependent variable and `x` is the independent variable.
+ * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_sxx(y: Column, x: Column): Column = withAggregateFunction { RegrSXX(y.expr, x.expr) } + + /** + * Aggregate function: returns REGR_COUNT(y, x) * COVAR_POP(y, x) for non-null pairs + * in a group, where `y` is the dependent variable and `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_sxy(y: Column, x: Column): Column = withAggregateFunction { RegrSXY(y.expr, x.expr) } + + /** + * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(y) for non-null pairs + * in a group, where `y` is the dependent variable and `x` is the independent variable. + * + * @group agg_funcs + * @since 3.5.0 + */ + def regr_syy(y: Column, x: Column): Column = withAggregateFunction { RegrSYY(y.expr, x.expr) } ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e8c133646e2..52010c6350c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -463,6 +463,19 @@ class DataFrameAggregateSuite extends QueryTest checkAggregatesWithTol(sparkKurtosis, Row(-1.5), absTol) } + test("linear regression") { + checkAnswer(testData2.agg(regr_avgx($"a", $"b")), testData2.selectExpr("regr_avgx(a, b)")) + checkAnswer(testData2.agg(regr_avgy($"a", $"b")), testData2.selectExpr("regr_avgy(a, b)")) + checkAnswer(testData2.agg(regr_count($"a", $"b")), testData2.selectExpr("regr_count(a, b)")) + checkAnswer( + testData2.agg(regr_intercept($"a", $"b")), testData2.selectExpr("regr_intercept(a, b)")) + checkAnswer(testData2.agg(regr_r2($"a", $"b")), testData2.selectExpr("regr_r2(a, b)")) + checkAnswer(testData2.agg(regr_slope($"a", $"b")), testData2.selectExpr("regr_slope(a, 
b)")) + checkAnswer(testData2.agg(regr_sxx($"a", $"b")), testData2.selectExpr("regr_sxx(a, b)")) + checkAnswer(testData2.agg(regr_sxy($"a", $"b")), testData2.selectExpr("regr_sxy(a, b)")) + checkAnswer(testData2.agg(regr_syy($"a", $"b")), testData2.selectExpr("regr_syy(a, b)")) + } + test("zero moments") { withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true") { val input = Seq((1, 2)).toDF("a", "b") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org