This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new e6a0385 [SPARK-28422][SQL][PYTHON] GROUPED_AGG pandas_udf should work without group by clause e6a0385 is described below commit e6a0385289f2d2fec05d3fb5f798903de292c381 Author: Liang-Chi Hsieh <vii...@gmail.com> AuthorDate: Wed Aug 14 00:32:33 2019 +0900 [SPARK-28422][SQL][PYTHON] GROUPED_AGG pandas_udf should work without group by clause ## What changes were proposed in this pull request? A GROUPED_AGG pandas Python UDF can't work without a group by clause, e.g. `select udf(id) from table`. This doesn't match the behavior of aggregate functions like sum, count..., or the Dataset API like `df.agg(udf(df['id']))`. When we parse such a UDF (or an aggregate function) from SQL syntax, it is known as a function in a project. The `GlobalAggregates` rule in analysis turns such a project into an aggregate by looking for aggregate expressions. At this point, it should also look for GROUPED_AGG pandas Python UDFs. ## How was this patch tested? Added tests. Closes #25352 from viirya/SPARK-28422. 
Authored-by: Liang-Chi Hsieh <vii...@gmail.com> Signed-off-by: HyukjinKwon <gurwls...@apache.org> --- .../sql/tests/test_pandas_udf_grouped_agg.py | 15 ++++++++++++++ .../spark/sql/catalyst/analysis/Analyzer.scala | 5 ++++- .../apache/spark/sql/catalyst/plans/PlanTest.scala | 2 ++ .../python/BatchEvalPythonExecSuite.scala | 24 +++++++++++++++++++++- 4 files changed, 44 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py index 041b2b5..6d460df 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py @@ -474,6 +474,21 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase): result = df.groupBy('id').agg(f(df['x']).alias('sum')).collect() self.assertEqual(result, expected) + def test_grouped_without_group_by_clause(self): + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def max_udf(v): + return v.max() + + df = self.spark.range(0, 100) + self.spark.udf.register('max_udf', max_udf) + + with self.tempView("table"): + df.createTempView('table') + + agg1 = df.agg(max_udf(df['id'])) + agg2 = self.spark.sql("select max_udf(id) from table") + assert_frame_equal(agg1.toPandas(), agg2.toPandas()) + if __name__ == "__main__": from pyspark.sql.tests.test_pandas_udf_grouped_agg import * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f8eef0c..5a04d57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1799,15 +1799,18 @@ class Analyzer( def containsAggregates(exprs: Seq[Expression]): Boolean = { // Collect all Windowed Aggregate Expressions. 
- val windowedAggExprs = exprs.flatMap { expr => + val windowedAggExprs: Set[Expression] = exprs.flatMap { expr => expr.collect { case WindowExpression(ae: AggregateExpression, _) => ae + case WindowExpression(e: PythonUDF, _) if PythonUDF.isGroupedAggPandasUDF(e) => e } }.toSet // Find the first Aggregate Expression that is not Windowed. exprs.exists(_.collectFirst { case ae: AggregateExpression if !windowedAggExprs.contains(ae) => ae + case e: PythonUDF if PythonUDF.isGroupedAggPandasUDF(e) && + !windowedAggExprs.contains(e) => e }.isDefined) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 6e2a842..08f1f87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -81,6 +81,8 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => ae.copy(resultId = ExprId(0)) case lv: NamedLambdaVariable => lv.copy(exprId = ExprId(0), value = null) + case udf: PythonUDF => + udf.copy(resultId = ExprId(0)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 289cc66..ac5752b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In} import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.BooleanType +import 
org.apache.spark.sql.types.{BooleanType, DoubleType} class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { import testImplicits.newProductEncoder @@ -100,6 +100,21 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { } assert(qualifiedPlanNodes.size == 1) } + + test("SPARK-28422: GROUPED_AGG pandas_udf should work without group by clause") { + val aggPandasUdf = new MyDummyGroupedAggPandasUDF + spark.udf.registerPython("dummyGroupedAggPandasUDF", aggPandasUdf) + + withTempView("table") { + val df = spark.range(0, 100) + df.createTempView("table") + + val agg1 = df.agg(aggPandasUdf(df("id"))) + val agg2 = sql("select dummyGroupedAggPandasUDF(id) from table") + + comparePlans(agg1.queryExecution.optimizedPlan, agg2.queryExecution.optimizedPlan) + } + } } // This Python UDF is dummy and just for testing. Unable to execute. @@ -119,6 +134,13 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( pythonEvalType = PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) +class MyDummyGroupedAggPandasUDF extends UserDefinedPythonFunction( + name = "dummyGroupedAggPandasUDF", + func = new DummyUDF, + dataType = DoubleType, + pythonEvalType = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + udfDeterministic = true) + class MyDummyScalarPandasUDF extends UserDefinedPythonFunction( name = "dummyScalarPandasUDF", func = new DummyUDF, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org