This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e6a0385  [SPARK-28422][SQL][PYTHON] GROUPED_AGG pandas_udf should work without group by clause
e6a0385 is described below

commit e6a0385289f2d2fec05d3fb5f798903de292c381
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Wed Aug 14 00:32:33 2019 +0900

    [SPARK-28422][SQL][PYTHON] GROUPED_AGG pandas_udf should work without group by clause
    
    ## What changes were proposed in this pull request?
    
    A GROUPED_AGG pandas Python UDF cannot be used without a group by clause, e.g. `select udf(id) from table`.
    
    This is inconsistent with built-in aggregate functions such as sum and count, and with the Dataset API, e.g. `df.agg(udf(df['id']))`.
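    
    For example, a minimal sketch mirroring the test added below (it assumes an active SparkSession bound to `spark`; `max_udf` and `table` are illustrative names):
    
    ```python
    from pyspark.sql.functions import pandas_udf, PandasUDFType
    
    @pandas_udf('double', PandasUDFType.GROUPED_AGG)
    def max_udf(v):
        return v.max()
    
    df = spark.range(100)
    spark.udf.register('max_udf', max_udf)
    df.createOrReplaceTempView('table')
    
    # Works via the Dataset API.
    df.agg(max_udf(df['id'])).show()
    # Equivalent SQL without a group by clause; failed before this change.
    spark.sql("select max_udf(id) from table").show()
    ```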
    
    When such a UDF (or an aggregate function) is parsed from SQL syntax, it appears as a function in a Project. The `GlobalAggregates` rule in analysis turns such a Project into an Aggregate by looking for aggregate expressions. At that point, the rule should also look for GROUPED_AGG pandas Python UDFs.
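    
    Conceptually, with this change the SQL query above resolves to the same Aggregate plan as the Dataset API path. One way to inspect this (a sketch; the exact plan output varies by version):
    
    ```python
    # Prints the parsed, analyzed, optimized and physical plans; after this
    # change the plan contains an Aggregate node for the grouped-agg pandas
    # UDF even though the query has no group by clause.
    spark.sql("select max_udf(id) from table").explain(True)
    ```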
    
    ## How was this patch tested?
    
    Added tests.
    
    Closes #25352 from viirya/SPARK-28422.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: HyukjinKwon <gurwls...@apache.org>
---
 .../sql/tests/test_pandas_udf_grouped_agg.py       | 15 ++++++++++++++
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  5 ++++-
 .../apache/spark/sql/catalyst/plans/PlanTest.scala |  2 ++
 .../python/BatchEvalPythonExecSuite.scala          | 24 +++++++++++++++++++++-
 4 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
index 041b2b5..6d460df 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
@@ -474,6 +474,21 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
         result = df.groupBy('id').agg(f(df['x']).alias('sum')).collect()
         self.assertEqual(result, expected)
 
+    def test_grouped_without_group_by_clause(self):
+        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+        def max_udf(v):
+            return v.max()
+
+        df = self.spark.range(0, 100)
+        self.spark.udf.register('max_udf', max_udf)
+
+        with self.tempView("table"):
+            df.createTempView('table')
+
+            agg1 = df.agg(max_udf(df['id']))
+            agg2 = self.spark.sql("select max_udf(id) from table")
+            assert_frame_equal(agg1.toPandas(), agg2.toPandas())
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf_grouped_agg import *
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f8eef0c..5a04d57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1799,15 +1799,18 @@ class Analyzer(
 
     def containsAggregates(exprs: Seq[Expression]): Boolean = {
       // Collect all Windowed Aggregate Expressions.
-      val windowedAggExprs = exprs.flatMap { expr =>
+      val windowedAggExprs: Set[Expression] = exprs.flatMap { expr =>
         expr.collect {
           case WindowExpression(ae: AggregateExpression, _) => ae
+          case WindowExpression(e: PythonUDF, _) if PythonUDF.isGroupedAggPandasUDF(e) => e
         }
       }.toSet
 
       // Find the first Aggregate Expression that is not Windowed.
       exprs.exists(_.collectFirst {
         case ae: AggregateExpression if !windowedAggExprs.contains(ae) => ae
+        case e: PythonUDF if PythonUDF.isGroupedAggPandasUDF(e) &&
+          !windowedAggExprs.contains(e) => e
       }.isDefined)
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 6e2a842..08f1f87 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -81,6 +81,8 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite =>
         ae.copy(resultId = ExprId(0))
       case lv: NamedLambdaVariable =>
         lv.copy(exprId = ExprId(0), value = null)
+      case udf: PythonUDF =>
+        udf.copy(resultId = ExprId(0))
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index 289cc66..ac5752b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
 import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.BooleanType
+import org.apache.spark.sql.types.{BooleanType, DoubleType}
 
 class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
   import testImplicits.newProductEncoder
@@ -100,6 +100,21 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
     }
     assert(qualifiedPlanNodes.size == 1)
   }
+
+  test("SPARK-28422: GROUPED_AGG pandas_udf should work without group by 
clause") {
+    val aggPandasUdf = new MyDummyGroupedAggPandasUDF
+    spark.udf.registerPython("dummyGroupedAggPandasUDF", aggPandasUdf)
+
+    withTempView("table") {
+      val df = spark.range(0, 100)
+      df.createTempView("table")
+
+      val agg1 = df.agg(aggPandasUdf(df("id")))
+      val agg2 = sql("select dummyGroupedAggPandasUDF(id) from table")
+
+      comparePlans(agg1.queryExecution.optimizedPlan, agg2.queryExecution.optimizedPlan)
+    }
+  }
 }
 
 // This Python UDF is dummy and just for testing. Unable to execute.
@@ -119,6 +134,13 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction(
   pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
   udfDeterministic = true)
 
+class MyDummyGroupedAggPandasUDF extends UserDefinedPythonFunction(
+  name = "dummyGroupedAggPandasUDF",
+  func = new DummyUDF,
+  dataType = DoubleType,
+  pythonEvalType = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+  udfDeterministic = true)
+
 class MyDummyScalarPandasUDF extends UserDefinedPythonFunction(
   name = "dummyScalarPandasUDF",
   func = new DummyUDF,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
