Repository: spark
Updated Branches:
  refs/heads/branch-2.4 443d12dbb -> 0763b758d


[SPARK-25601][PYTHON] Register Grouped aggregate UDF Vectorized UDFs for SQL Statement

## What changes were proposed in this pull request?

This PR proposes to allow registering grouped aggregate vectorized (Pandas) UDFs
so that they can be used in SQL statements, for instance:

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf("integer", PandasUDFType.GROUPED_AGG)
def sum_udf(v):
    return v.sum()

spark.udf.register("sum_udf", sum_udf)
q = "SELECT v2, sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
spark.sql(q).show()
```

```
+---+-----------+
| v2|sum_udf(v1)|
+---+-----------+
|  1|          1|
|  0|          5|
+---+-----------+
```

## How was this patch tested?

Manually tested, and covered by a new unit test
(`GroupedAggPandasUDFTests.test_register_vectorized_udf_basic`).

Closes #22620 from HyukjinKwon/SPARK-25601.

Authored-by: hyukjinkwon <gurwls...@apache.org>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0763b758
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0763b758
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0763b758

Branch: refs/heads/branch-2.4
Commit: 0763b758de55fd14d7da4832d01b5713e582b257
Parents: 443d12d
Author: hyukjinkwon <gurwls...@apache.org>
Authored: Thu Oct 4 09:36:23 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Thu Oct 4 09:43:42 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py | 20 ++++++++++++++++++--
 python/pyspark/sql/udf.py   | 15 +++++++++++++--
 2 files changed, 31 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0763b758/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 690035a..e991032 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -5595,8 +5595,9 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
 
         foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(ValueError, 'f must be either 
SQL_BATCHED_UDF or '
-                                                     'SQL_SCALAR_PANDAS_UDF'):
+            with self.assertRaisesRegexp(
+                    ValueError,
+                    
'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'):
                 self.spark.catalog.registerFunction("foo_udf", foo_udf)
 
     def test_decorator(self):
@@ -6412,6 +6413,21 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
                     'mixture.*aggregate function.*group aggregate pandas UDF'):
                 df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
 
+    def test_register_vectorized_udf_basic(self):
+        from pyspark.sql.functions import pandas_udf
+        from pyspark.rdd import PythonEvalType
+
+        sum_pandas_udf = pandas_udf(
+            lambda v: v.sum(), "integer", 
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+
+        self.assertEqual(sum_pandas_udf.evalType, 
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+        group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", 
sum_pandas_udf)
+        self.assertEqual(group_agg_pandas_udf.evalType, 
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+        q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) 
tbl(v1, v2) GROUP BY v2"
+        actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
+        expected = [1, 5]
+        self.assertEqual(actual, expected)
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,

http://git-wip-us.apache.org/repos/asf/spark/blob/0763b758/python/pyspark/sql/udf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 9dbe49b..58f4e0d 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -298,6 +298,15 @@ class UDFRegistration(object):
             >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # 
doctest: +SKIP
             [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
 
+            >>> @pandas_udf("integer", PandasUDFType.GROUPED_AGG)  # doctest: 
+SKIP
+            ... def sum_udf(v):
+            ...     return v.sum()
+            ...
+            >>> _ = spark.udf.register("sum_udf", sum_udf)  # doctest: +SKIP
+            >>> q = "SELECT sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) 
tbl(v1, v2) GROUP BY v2"
+            >>> spark.sql(q).collect()  # doctest: +SKIP
+            [Row(sum_udf(v1)=1), Row(sum_udf(v1)=5)]
+
             .. note:: Registration for a user-defined function (case 2.) was 
added from
                 Spark 2.3.0.
         """
@@ -310,9 +319,11 @@ class UDFRegistration(object):
                     "Invalid returnType: data type can not be specified when f 
is"
                     "a user-defined function, but got %s." % returnType)
             if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
-                                  PythonEvalType.SQL_SCALAR_PANDAS_UDF]:
+                                  PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+                                  PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]:
                 raise ValueError(
-                    "Invalid f: f must be either SQL_BATCHED_UDF or 
SQL_SCALAR_PANDAS_UDF")
+                    "Invalid f: f must be SQL_BATCHED_UDF, 
SQL_SCALAR_PANDAS_UDF or "
+                    "SQL_GROUPED_AGG_PANDAS_UDF")
             register_udf = UserDefinedFunction(f.func, 
returnType=f.returnType, name=name,
                                                evalType=f.evalType,
                                                deterministic=f.deterministic)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to