This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 77c7e91  [SPARK-28445][SQL][PYTHON] Fix error when PythonUDF is used in both group by and aggregate expression
77c7e91 is described below

commit 77c7e91e029a9a70678435acb141154f2f51882e
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Fri Aug 2 19:47:29 2019 +0900

    [SPARK-28445][SQL][PYTHON] Fix error when PythonUDF is used in both group by and aggregate expression

    ## What changes were proposed in this pull request?

    When a PythonUDF is used in the group by clause and also in an aggregate expression, like

    ```
    SELECT pyUDF(a + 1), COUNT(b) FROM testData GROUP BY pyUDF(a + 1)
    ```

    it causes an analysis exception in `CheckAnalysis`:

    ```
    org.apache.spark.sql.AnalysisException: expression 'testdata.`a`' is neither present in the group by, nor is it an aggregate function.
    ```

    First, `CheckAnalysis` can't check semantic equality between PythonUDFs. Second, even if we made that possible, a runtime exception would be thrown:

    ```
    org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: pythonUDF1#8615
    ...
    Cause: java.lang.RuntimeException: Couldn't find pythonUDF1#8615 in [cast(pythonUDF0#8614 as int)#8617,count(b#8599)#8607L]
    ```

    The cause is that `ExtractPythonUDFs` extracts both PythonUDFs, the one in the group by and the one in the aggregate expression, so they become two different aliases in the logical aggregate. At runtime, we can't bind the resulting expression in the aggregate to its grouping and aggregate attributes.

    This patch proposes a rule `ExtractGroupingPythonUDFFromAggregate` that extracts PythonUDFs in the group by and evaluates them before the aggregate. The group-by PythonUDF in the aggregate expression is replaced with the aliased result.

    The query plan of the query `SELECT pyUDF(a + 1), pyUDF(COUNT(b)) FROM testData GROUP BY pyUDF(a + 1)` looks like:

    ```
    == Optimized Logical Plan ==
    Project [CAST(pyUDF(cast((a + 1) as string)) AS INT)#8608, cast(pythonUDF0#8616 as bigint) AS CAST(pyUDF(cast(count(b) as string)) AS BIGINT)#8610L]
    +- BatchEvalPython [pyUDF(cast(agg#8613L as string))], [pythonUDF0#8616]
       +- Aggregate [cast(groupingPythonUDF#8614 as int)], [cast(groupingPythonUDF#8614 as int) AS CAST(pyUDF(cast((a + 1) as string)) AS INT)#8608, count(b#8599) AS agg#8613L]
          +- Project [pythonUDF0#8615 AS groupingPythonUDF#8614, b#8599]
             +- BatchEvalPython [pyUDF(cast((a#8598 + 1) as string))], [pythonUDF0#8615]
                +- LocalRelation [a#8598, b#8599]

    == Physical Plan ==
    *(3) Project [CAST(pyUDF(cast((a + 1) as string)) AS INT)#8608, cast(pythonUDF0#8616 as bigint) AS CAST(pyUDF(cast(count(b) as string)) AS BIGINT)#8610L]
    +- BatchEvalPython [pyUDF(cast(agg#8613L as string))], [pythonUDF0#8616]
       +- *(2) HashAggregate(keys=[cast(groupingPythonUDF#8614 as int)#8617], functions=[count(b#8599)], output=[CAST(pyUDF(cast((a + 1) as string)) AS INT)#8608, agg#8613L])
          +- Exchange hashpartitioning(cast(groupingPythonUDF#8614 as int)#8617, 5), true
             +- *(1) HashAggregate(keys=[cast(groupingPythonUDF#8614 as int) AS cast(groupingPythonUDF#8614 as int)#8617], functions=[partial_count(b#8599)], output=[cast(groupingPythonUDF#8614 as int)#8617, count#8619L])
                +- *(1) Project [pythonUDF0#8615 AS groupingPythonUDF#8614, b#8599]
                   +- BatchEvalPython [pyUDF(cast((a#8598 + 1) as string))], [pythonUDF0#8615]
                      +- LocalTableScan [a#8598, b#8599]
    ```

    ## How was this patch tested?

    Added tests.

    Closes #25215 from viirya/SPARK-28445.
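For readers who want to reproduce the failure mode described above, here is a minimal PySpark sketch (not part of the patch; the local session, the sample table, and the `pyUDF` registration are illustrative assumptions) of a query that puts the same Python UDF in both the GROUP BY and the select list:

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

# Hypothetical local setup, only to reproduce the query shape discussed above.
spark = SparkSession.builder.master("local[2]").appName("SPARK-28445-repro").getOrCreate()

# Register a trivial Python UDF under the name used in the commit message.
spark.udf.register("pyUDF", lambda x: x + 1, IntegerType())

# A small stand-in for the `testData` table referenced in the example query.
spark.createDataFrame([(1, 10), (1, 20), (2, 30)], ["a", "b"]) \
    .createOrReplaceTempView("testData")

# The same Python UDF appears in GROUP BY and in an aggregate expression.
# Before this patch the analyzer rejected it ("expression 'testdata.`a`' is
# neither present in the group by, nor is it an aggregate function."); with
# ExtractGroupingPythonUDFFromAggregate the UDF is evaluated once before the
# aggregate and the query runs.
spark.sql("SELECT pyUDF(a + 1), COUNT(b) FROM testData GROUP BY pyUDF(a + 1)").show()
```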
Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
Signed-off-by: HyukjinKwon <gurwls...@apache.org>
---
 .../spark/sql/catalyst/expressions/PythonUDF.scala |  6 ++
 .../spark/sql/execution/SparkOptimizer.scala       |  6 +-
 .../sql/execution/python/ExtractPythonUDFs.scala   | 63 +++++++++++++++++++
 .../sql/execution/python/PythonUDFSuite.scala      | 71 ++++++++++++++++++++++
 4 files changed, 144 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index 690969e..da2e182 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -67,4 +67,10 @@ case class PythonUDF(
     exprId = resultId)
 
   override def nullable: Boolean = true
+
+  override lazy val canonicalized: Expression = {
+    val canonicalizedChildren = children.map(_.canonicalized)
+    // `resultId` can be seen as cosmetic variation in PythonUDF, as it doesn't affect the result.
+    this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 4ae2194..d4fc92c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.{ColumnPruning, Optimizer, PushPredicateThroughNonJoin, RemoveNoopOperators}
 import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
 import org.apache.spark.sql.execution.datasources.SchemaPruning
-import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
+import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
 
 class SparkOptimizer(
     catalog: SessionCatalog,
@@ -33,6 +33,8 @@ class SparkOptimizer(
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
     Batch("Extract Python UDFs", Once,
       ExtractPythonUDFFromAggregate,
+      // This must be executed after `ExtractPythonUDFFromAggregate` and before `ExtractPythonUDFs`.
+      ExtractGroupingPythonUDFFromAggregate,
       ExtractPythonUDFs,
       // The eval-python node may be between Project/Filter and the scan node, which breaks
       // column pruning and filter push-down. Here we rerun the related optimizer rules.
@@ -45,7 +47,7 @@ class SparkOptimizer(
     Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
 
   override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+
-    ExtractPythonUDFFromAggregate.ruleName :+
+    ExtractPythonUDFFromAggregate.ruleName :+ ExtractGroupingPythonUDFFromAggregate.ruleName :+
     ExtractPythonUDFs.ruleName
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index fc4ded3..d49d790 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -81,6 +81,69 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Extracts PythonUDFs that are used in grouping keys of a logical aggregate and evaluates them
+ * before the aggregate.
+ * This must be executed after the `ExtractPythonUDFFromAggregate` rule and before `ExtractPythonUDFs`.
+ */
+object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] {
+  private def hasScalarPythonUDF(e: Expression): Boolean = {
+    e.find(PythonUDF.isScalarPythonUDF).isDefined
+  }
+
+  private def extract(agg: Aggregate): LogicalPlan = {
+    val projList = new ArrayBuffer[NamedExpression]()
+    val groupingExpr = new ArrayBuffer[Expression]()
+    val attributeMap = mutable.HashMap[PythonUDF, NamedExpression]()
+
+    agg.groupingExpressions.foreach { expr =>
+      if (hasScalarPythonUDF(expr)) {
+        val newE = expr transformDown {
+          case p: PythonUDF =>
+            // This is just a sanity check, the rule PullOutNondeterministic should
+            // already pull out those nondeterministic expressions.
+            assert(p.udfDeterministic, "Non-deterministic PythonUDFs should not appear " +
+              "in grouping expression")
+            val canonicalized = p.canonicalized.asInstanceOf[PythonUDF]
+            if (attributeMap.contains(canonicalized)) {
+              attributeMap(canonicalized)
+            } else {
+              val alias = Alias(p, "groupingPythonUDF")()
+              projList += alias
+              attributeMap += ((canonicalized, alias.toAttribute))
+              alias.toAttribute
+            }
+        }
+        groupingExpr += newE
+      } else {
+        groupingExpr += expr
+      }
+    }
+    val aggExpr = agg.aggregateExpressions.map { expr =>
+      expr.transformUp {
+        // A PythonUDF over an aggregate was already pulled out by ExtractPythonUDFFromAggregate.
+        // A PythonUDF here should be either
+        // 1. An argument of an aggregate function.
+        //    CheckAnalysis guarantees the arguments are deterministic.
+        // 2. A PythonUDF in the grouping key. Grouping keys must be deterministic.
+        // 3. A PythonUDF not in the grouping key. It either has no arguments or has the grouping
+        //    key in its arguments. Such a PythonUDF was also pulled out by ExtractPythonUDFFromAggregate.
+        case p: PythonUDF if p.udfDeterministic =>
+          val canonicalized = p.canonicalized.asInstanceOf[PythonUDF]
+          attributeMap.getOrElse(canonicalized, p)
+      }.asInstanceOf[NamedExpression]
+    }
+    agg.copy(
+      groupingExpressions = groupingExpr,
+      aggregateExpressions = aggExpr,
+      child = Project(projList ++ agg.child.output, agg.child))
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case agg: Aggregate if agg.groupingExpressions.exists(hasScalarPythonUDF(_)) =>
+      extract(agg)
+  }
+}
 
 /**
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
new file mode 100644
index 0000000..1a971b0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest}
+import org.apache.spark.sql.functions.count
+import org.apache.spark.sql.test.SharedSQLContext
+
+class PythonUDFSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  import IntegratedUDFTestUtils._
+
+  val scalaTestUDF = TestScalaUDF(name = "scalaUDF")
+  val pythonTestUDF = TestPythonUDF(name = "pyUDF")
+  assume(shouldTestPythonUDFs)
+
+  lazy val base = Seq(
+    (Some(1), Some(1)), (Some(1), Some(2)), (Some(2), Some(1)),
+    (Some(2), Some(2)), (Some(3), Some(1)), (Some(3), Some(2)),
+    (None, Some(1)), (Some(3), None), (None, None)).toDF("a", "b")
+
+  test("SPARK-28445: PythonUDF as grouping key and aggregate expressions") {
+    val df1 = base.groupBy(scalaTestUDF(base("a") + 1))
+      .agg(scalaTestUDF(base("a") + 1), scalaTestUDF(count(base("b"))))
+    val df2 = base.groupBy(pythonTestUDF(base("a") + 1))
+      .agg(pythonTestUDF(base("a") + 1), pythonTestUDF(count(base("b"))))
+    checkAnswer(df1, df2)
+  }
+
+  test("SPARK-28445: PythonUDF as grouping key and used in aggregate expressions") {
+    val df1 = base.groupBy(scalaTestUDF(base("a") + 1))
+      .agg(scalaTestUDF(base("a") + 1) + 1, scalaTestUDF(count(base("b"))))
+    val df2 = base.groupBy(pythonTestUDF(base("a") + 1))
+      .agg(pythonTestUDF(base("a") + 1) + 1, pythonTestUDF(count(base("b"))))
+    checkAnswer(df1, df2)
+  }
+
+  test("SPARK-28445: PythonUDF in aggregate expression has grouping key in its arguments") {
+    val df1 = base.groupBy(scalaTestUDF(base("a") + 1))
+      .agg(scalaTestUDF(scalaTestUDF(base("a") + 1)), scalaTestUDF(count(base("b"))))
+    val df2 = base.groupBy(pythonTestUDF(base("a") + 1))
+      .agg(pythonTestUDF(pythonTestUDF(base("a") + 1)), pythonTestUDF(count(base("b"))))
+    checkAnswer(df1, df2)
+  }
+
+  test("SPARK-28445: PythonUDF over grouping key is argument to aggregate function") {
+    val df1 = base.groupBy(scalaTestUDF(base("a") + 1))
+      .agg(scalaTestUDF(scalaTestUDF(base("a") + 1)),
+        scalaTestUDF(count(scalaTestUDF(base("a") + 1))))
+    val df2 = base.groupBy(pythonTestUDF(base("a") + 1))
+      .agg(pythonTestUDF(pythonTestUDF(base("a") + 1)),
+        pythonTestUDF(count(pythonTestUDF(base("a") + 1))))
+    checkAnswer(df1, df2)
+  }
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org