Repository: spark Updated Branches: refs/heads/branch-2.4 443d12dbb -> 0763b758d
[SPARK-25601][PYTHON] Register Grouped aggregate UDF Vectorized UDFs for SQL Statement ## What changes were proposed in this pull request? This PR proposes to allow grouped aggregate vectorized (pandas) UDFs to be registered for use in SQL statements, for instance: ```python from pyspark.sql.functions import pandas_udf, PandasUDFType @pandas_udf("integer", PandasUDFType.GROUPED_AGG) def sum_udf(v): return v.sum() spark.udf.register("sum_udf", sum_udf) q = "SELECT v2, sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2" spark.sql(q).show() ``` ``` +---+-----------+ | v2|sum_udf(v1)| +---+-----------+ | 1| 1| | 0| 5| +---+-----------+ ``` ## How was this patch tested? Manual test and unit test. Closes #22620 from HyukjinKwon/SPARK-25601. Authored-by: hyukjinkwon <gurwls...@apache.org> Signed-off-by: hyukjinkwon <gurwls...@apache.org> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0763b758 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0763b758 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0763b758 Branch: refs/heads/branch-2.4 Commit: 0763b758de55fd14d7da4832d01b5713e582b257 Parents: 443d12d Author: hyukjinkwon <gurwls...@apache.org> Authored: Thu Oct 4 09:36:23 2018 +0800 Committer: hyukjinkwon <gurwls...@apache.org> Committed: Thu Oct 4 09:43:42 2018 +0800 ---------------------------------------------------------------------- python/pyspark/sql/tests.py | 20 ++++++++++++++++++-- python/pyspark/sql/udf.py | 15 +++++++++++++-- 2 files changed, 31 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/0763b758/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 690035a..e991032 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5595,8 +5595,9 
@@ class GroupedMapPandasUDFTests(ReusedSQLTestCase): foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or ' - 'SQL_SCALAR_PANDAS_UDF'): + with self.assertRaisesRegexp( + ValueError, + 'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'): self.spark.catalog.registerFunction("foo_udf", foo_udf) def test_decorator(self): @@ -6412,6 +6413,21 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase): 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + def test_register_vectorized_udf_basic(self): + from pyspark.sql.functions import pandas_udf + from pyspark.rdd import PythonEvalType + + sum_pandas_udf = pandas_udf( + lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) + + self.assertEqual(sum_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) + group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", sum_pandas_udf) + self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) + q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2" + actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect())) + expected = [1, 5] + self.assertEqual(actual, expected) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, http://git-wip-us.apache.org/repos/asf/spark/blob/0763b758/python/pyspark/sql/udf.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 9dbe49b..58f4e0d 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -298,6 +298,15 @@ class UDFRegistration(object): >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + >>> @pandas_udf("integer", 
PandasUDFType.GROUPED_AGG) # doctest: +SKIP + ... def sum_udf(v): + ... return v.sum() + ... + >>> _ = spark.udf.register("sum_udf", sum_udf) # doctest: +SKIP + >>> q = "SELECT sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2" + >>> spark.sql(q).collect() # doctest: +SKIP + [Row(sum_udf(v1)=1), Row(sum_udf(v1)=5)] + .. note:: Registration for a user-defined function (case 2.) was added from Spark 2.3.0. """ @@ -310,9 +319,11 @@ class UDFRegistration(object): "Invalid returnType: data type can not be specified when f is" "a user-defined function, but got %s." % returnType) if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_SCALAR_PANDAS_UDF]: + PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]: raise ValueError( - "Invalid f: f must be either SQL_BATCHED_UDF or SQL_SCALAR_PANDAS_UDF") + "Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF or " + "SQL_GROUPED_AGG_PANDAS_UDF") register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, evalType=f.evalType, deterministic=f.deterministic) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org