This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 77c7e91  [SPARK-28445][SQL][PYTHON] Fix error when PythonUDF is used in both group by and aggregate expression
77c7e91 is described below

commit 77c7e91e029a9a70678435acb141154f2f51882e
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Fri Aug 2 19:47:29 2019 +0900

    [SPARK-28445][SQL][PYTHON] Fix error when PythonUDF is used in both group by and aggregate expression
    
    ## What changes were proposed in this pull request?
    
    When a PythonUDF appears both in the group by clause and in an aggregate expression, for example
    
    ```
    SELECT pyUDF(a + 1), COUNT(b) FROM testData GROUP BY pyUDF(a + 1)
    ```
    
    the query fails in `CheckAnalysis` with an analysis exception:
    ```
    org.apache.spark.sql.AnalysisException: expression 'testdata.`a`' is neither present in the group by, nor is it an aggregate function.
    ```
    
    First, `CheckAnalysis` cannot check semantic equality between PythonUDFs.
    Second, even if we made that possible, a runtime exception would still be thrown:
    
    ```
    org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: pythonUDF1#8615
    ...
    Cause: java.lang.RuntimeException: Couldn't find pythonUDF1#8615 in [cast(pythonUDF0#8614 as int)#8617,count(b#8599)#8607L]
    ```
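
    For reference, the same failure can be reproduced through the DataFrame API. A minimal sketch, assuming the test-only `TestPythonUDF` wrapper from `IntegratedUDFTestUtils` (the sample data and session setup are illustrative):

    ```
    import org.apache.spark.sql.IntegratedUDFTestUtils._
    import org.apache.spark.sql.functions.count

    // Assumes a SparkSession whose implicits are in scope, as in a Spark test suite.
    val pyUDF = TestPythonUDF(name = "pyUDF")
    val testData = Seq((1, 10), (2, 20)).toDF("a", "b")
    // Before this patch, CheckAnalysis rejects this plan with the AnalysisException above.
    testData.groupBy(pyUDF(testData("a") + 1))
      .agg(pyUDF(testData("a") + 1), count(testData("b")))
    ```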
    
    The cause is that `ExtractPythonUDFs` extracts the PythonUDFs in both the group by clause and the aggregate expression, producing two different aliases in the logical aggregate. At runtime, the resulting expression in the aggregate cannot be bound to its grouping and aggregate attributes.
    
    This patch proposes a rule, `ExtractGroupingPythonUDFFromAggregate`, that extracts PythonUDFs from the group by clause and evaluates them before the aggregate. The group by PythonUDF in the aggregate expression is then replaced with the aliased result.
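
    For a single grouping PythonUDF, the rewrite has roughly this shape (a simplified sketch of the rule body shown in the diff below; `groupingUdf` and `rewrittenAggExprs` stand for what the rule computes):

    ```
    // Evaluate the grouping UDF once in a Project below the Aggregate; both the
    // grouping key and the aggregate expressions then reference its alias.
    val alias = Alias(groupingUdf, "groupingPythonUDF")()
    agg.copy(
      groupingExpressions = Seq(alias.toAttribute),
      aggregateExpressions = rewrittenAggExprs, // grouping UDF replaced by alias.toAttribute
      child = Project(alias +: agg.child.output, agg.child))
    ```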
    
    The query plan of `SELECT pyUDF(a + 1), pyUDF(COUNT(b)) FROM testData GROUP BY pyUDF(a + 1)` now looks like:
    
    ```
    == Optimized Logical Plan ==
    Project [CAST(pyUDF(cast((a + 1) as string)) AS INT)#8608, cast(pythonUDF0#8616 as bigint) AS CAST(pyUDF(cast(count(b) as string)) AS BIGINT)#8610L]
    +- BatchEvalPython [pyUDF(cast(agg#8613L as string))], [pythonUDF0#8616]
       +- Aggregate [cast(groupingPythonUDF#8614 as int)], [cast(groupingPythonUDF#8614 as int) AS CAST(pyUDF(cast((a + 1) as string)) AS INT)#8608, count(b#8599) AS agg#8613L]
          +- Project [pythonUDF0#8615 AS groupingPythonUDF#8614, b#8599]
             +- BatchEvalPython [pyUDF(cast((a#8598 + 1) as string))], [pythonUDF0#8615]
                +- LocalRelation [a#8598, b#8599]

    == Physical Plan ==
    *(3) Project [CAST(pyUDF(cast((a + 1) as string)) AS INT)#8608, cast(pythonUDF0#8616 as bigint) AS CAST(pyUDF(cast(count(b) as string)) AS BIGINT)#8610L]
    +- BatchEvalPython [pyUDF(cast(agg#8613L as string))], [pythonUDF0#8616]
       +- *(2) HashAggregate(keys=[cast(groupingPythonUDF#8614 as int)#8617], functions=[count(b#8599)], output=[CAST(pyUDF(cast((a + 1) as string)) AS INT)#8608, agg#8613L])
          +- Exchange hashpartitioning(cast(groupingPythonUDF#8614 as int)#8617, 5), true
             +- *(1) HashAggregate(keys=[cast(groupingPythonUDF#8614 as int) AS cast(groupingPythonUDF#8614 as int)#8617], functions=[partial_count(b#8599)], output=[cast(groupingPythonUDF#8614 as int)#8617, count#8619L])
                +- *(1) Project [pythonUDF0#8615 AS groupingPythonUDF#8614, b#8599]
                   +- BatchEvalPython [pyUDF(cast((a#8598 + 1) as string))], [pythonUDF0#8615]
                      +- LocalTableScan [a#8598, b#8599]
    ```
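
    Deduplicating the two occurrences of `pyUDF(a + 1)` relies on making `PythonUDF` canonicalizable (the `PythonUDF.scala` change below): canonicalization normalizes the per-instance `resultId`, so semantically identical UDF calls compare equal. Roughly, with hypothetical `udf1` and `udf2`:

    ```
    // udf1 and udf2 are the same PythonUDF call but carry different resultIds.
    // With the new override, both canonicalize to a copy with resultId = ExprId(-1):
    udf1.canonicalized == udf2.canonicalized  // true
    udf1.semanticEquals(udf2)                 // true, so one shared alias suffices
    ```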
    
    ## How was this patch tested?
    
    Added tests.
    
    Closes #25215 from viirya/SPARK-28445.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: HyukjinKwon <gurwls...@apache.org>
---
 .../spark/sql/catalyst/expressions/PythonUDF.scala |  6 ++
 .../spark/sql/execution/SparkOptimizer.scala       |  6 +-
 .../sql/execution/python/ExtractPythonUDFs.scala   | 63 +++++++++++++++++++
 .../sql/execution/python/PythonUDFSuite.scala      | 71 ++++++++++++++++++++++
 4 files changed, 144 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index 690969e..da2e182 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -67,4 +67,10 @@ case class PythonUDF(
     exprId = resultId)
 
   override def nullable: Boolean = true
+
+  override lazy val canonicalized: Expression = {
+    val canonicalizedChildren = children.map(_.canonicalized)
+    // `resultId` can be seen as a cosmetic variation in PythonUDF, as it doesn't affect the result.
+    this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 4ae2194..d4fc92c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.{ColumnPruning, Optimizer, PushPredicateThroughNonJoin, RemoveNoopOperators}
 import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
 import org.apache.spark.sql.execution.datasources.SchemaPruning
-import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
+import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
 
 class SparkOptimizer(
     catalog: SessionCatalog,
@@ -33,6 +33,8 @@ class SparkOptimizer(
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
     Batch("Extract Python UDFs", Once,
       ExtractPythonUDFFromAggregate,
+      // This must be executed after `ExtractPythonUDFFromAggregate` and before `ExtractPythonUDFs`.
+      ExtractGroupingPythonUDFFromAggregate,
       ExtractPythonUDFs,
       // The eval-python node may be between Project/Filter and the scan node, which breaks
       // column pruning and filter push-down. Here we rerun the related optimizer rules.
@@ -45,7 +47,7 @@ class SparkOptimizer(
     Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
 
   override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+
-    ExtractPythonUDFFromAggregate.ruleName :+
+    ExtractPythonUDFFromAggregate.ruleName :+ ExtractGroupingPythonUDFFromAggregate.ruleName :+
     ExtractPythonUDFs.ruleName
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index fc4ded3..d49d790 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -81,6 +81,69 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Extracts PythonUDFs that are used in grouping keys of a logical aggregate and evaluates them
+ * before the aggregate.
+ * This must be executed after the `ExtractPythonUDFFromAggregate` rule and before
+ * `ExtractPythonUDFs`.
+ */
+object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] {
+  private def hasScalarPythonUDF(e: Expression): Boolean = {
+    e.find(PythonUDF.isScalarPythonUDF).isDefined
+  }
+
+  private def extract(agg: Aggregate): LogicalPlan = {
+    val projList = new ArrayBuffer[NamedExpression]()
+    val groupingExpr = new ArrayBuffer[Expression]()
+    val attributeMap = mutable.HashMap[PythonUDF, NamedExpression]()
+
+    agg.groupingExpressions.foreach { expr =>
+      if (hasScalarPythonUDF(expr)) {
+        val newE = expr transformDown {
+          case p: PythonUDF =>
+            // This is just a sanity check; the rule PullOutNondeterministic should
+            // already pull out those nondeterministic expressions.
+            assert(p.udfDeterministic, "Non-deterministic PythonUDFs should not appear " +
+              "in grouping expression")
+            val canonicalized = p.canonicalized.asInstanceOf[PythonUDF]
+            if (attributeMap.contains(canonicalized)) {
+              attributeMap(canonicalized)
+            } else {
+              val alias = Alias(p, "groupingPythonUDF")()
+              projList += alias
+              attributeMap += ((canonicalized, alias.toAttribute))
+              alias.toAttribute
+            }
+        }
+        groupingExpr += newE
+      } else {
+        groupingExpr += expr
+      }
+    }
+    val aggExpr = agg.aggregateExpressions.map { expr =>
+      expr.transformUp {
+        // A PythonUDF over an aggregate was already pulled out by ExtractPythonUDFFromAggregate.
+        // A PythonUDF here should be either:
+        // 1. An argument of an aggregate function.
+        //    CheckAnalysis guarantees the arguments are deterministic.
+        // 2. A PythonUDF in a grouping key. Grouping keys must be deterministic.
+        // 3. A PythonUDF not in a grouping key. It either has no arguments or takes grouping
+        //    keys as arguments. Such a PythonUDF was also pulled out by
+        //    ExtractPythonUDFFromAggregate.
+        case p: PythonUDF if p.udfDeterministic =>
+          val canonicalized = p.canonicalized.asInstanceOf[PythonUDF]
+          attributeMap.getOrElse(canonicalized, p)
+      }.asInstanceOf[NamedExpression]
+    }
+    agg.copy(
+      groupingExpressions = groupingExpr,
+      aggregateExpressions = aggExpr,
+      child = Project(projList ++ agg.child.output, agg.child))
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case agg: Aggregate if agg.groupingExpressions.exists(hasScalarPythonUDF(_)) =>
+      extract(agg)
+  }
+}
 
 /**
 * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
new file mode 100644
index 0000000..1a971b0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest}
+import org.apache.spark.sql.functions.count
+import org.apache.spark.sql.test.SharedSQLContext
+
+class PythonUDFSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  import IntegratedUDFTestUtils._
+
+  val scalaTestUDF = TestScalaUDF(name = "scalaUDF")
+  val pythonTestUDF = TestPythonUDF(name = "pyUDF")
+  assume(shouldTestPythonUDFs)
+
+  lazy val base = Seq(
+    (Some(1), Some(1)), (Some(1), Some(2)), (Some(2), Some(1)),
+    (Some(2), Some(2)), (Some(3), Some(1)), (Some(3), Some(2)),
+    (None, Some(1)), (Some(3), None), (None, None)).toDF("a", "b")
+
+  test("SPARK-28445: PythonUDF as grouping key and aggregate expressions") {
+    val df1 = base.groupBy(scalaTestUDF(base("a") + 1))
+      .agg(scalaTestUDF(base("a") + 1), scalaTestUDF(count(base("b"))))
+    val df2 = base.groupBy(pythonTestUDF(base("a") + 1))
+      .agg(pythonTestUDF(base("a") + 1), pythonTestUDF(count(base("b"))))
+    checkAnswer(df1, df2)
+  }
+
+  test("SPARK-28445: PythonUDF as grouping key and used in aggregate expressions") {
+    val df1 = base.groupBy(scalaTestUDF(base("a") + 1))
+      .agg(scalaTestUDF(base("a") + 1) + 1, scalaTestUDF(count(base("b"))))
+    val df2 = base.groupBy(pythonTestUDF(base("a") + 1))
+      .agg(pythonTestUDF(base("a") + 1) + 1, pythonTestUDF(count(base("b"))))
+    checkAnswer(df1, df2)
+  }
+
+  test("SPARK-28445: PythonUDF in aggregate expression has grouping key in its arguments") {
+    val df1 = base.groupBy(scalaTestUDF(base("a") + 1))
+      .agg(scalaTestUDF(scalaTestUDF(base("a") + 1)), scalaTestUDF(count(base("b"))))
+    val df2 = base.groupBy(pythonTestUDF(base("a") + 1))
+      .agg(pythonTestUDF(pythonTestUDF(base("a") + 1)), pythonTestUDF(count(base("b"))))
+    checkAnswer(df1, df2)
+  }
+
+  test("SPARK-28445: PythonUDF over grouping key is argument to aggregate function") {
+    val df1 = base.groupBy(scalaTestUDF(base("a") + 1))
+      .agg(scalaTestUDF(scalaTestUDF(base("a") + 1)),
+        scalaTestUDF(count(scalaTestUDF(base("a") + 1))))
+    val df2 = base.groupBy(pythonTestUDF(base("a") + 1))
+      .agg(pythonTestUDF(pythonTestUDF(base("a") + 1)),
+        pythonTestUDF(count(pythonTestUDF(base("a") + 1))))
+    checkAnswer(df1, df2)
+  }
+}
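
The new suite can be run with the usual Spark sbt workflow (assuming a local Python environment that satisfies the suite's `assume(shouldTestPythonUDFs)` guard):

```
build/sbt "sql/testOnly *PythonUDFSuite"
```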

