Repository: spark
Updated Branches:
  refs/heads/branch-2.3 d656be74b -> 84a189a34


[SPARK-23177][SQL][PYSPARK][BACKPORT-2.3] Extract zero-parameter UDFs from aggregate

## What changes were proposed in this pull request?

The `ExtractPythonUDFFromAggregate` rule extracts Python UDFs in a logical aggregate that depend on an aggregate expression or a grouping key, and evaluates them after the aggregate. However, Python UDFs that depend on none of those expressions, such as zero-parameter UDFs, should also be extracted, to avoid the issue reported in the JIRA.

A small code snippet that reproduces the issue:
```python
import pyspark.sql.functions as f

df = spark.createDataFrame([(1, 2), (3, 4)])
f_udf = f.udf(lambda: "const_str")  # zero-parameter UDF returning a constant
df2 = df.distinct().withColumn("a", f_udf())
df2.show()
```
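
With the fix applied, the snippet completes instead of failing. Given the input data and the constant UDF, the expected output should look roughly like this (row order may vary):
```
+---+---+---------+
| _1| _2|        a|
+---+---+---------+
|  1|  2|const_str|
|  3|  4|const_str|
+---+---+---------+
```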

The following exception is raised:
```
: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: pythonUDF0#50
        at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:91)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:90)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
        at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:266)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306)
        at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
        at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:256)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:90)
        at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$38.apply(HashAggregateExec.scala:514)
        at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$38.apply(HashAggregateExec.scala:513)
```

This exception is raised because `HashAggregateExec` tries to bind the aliased Python UDF expression (e.g., `pythonUDF0#50 AS a#44`) to the grouping keys.
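
To make the two cases concrete, here is a minimal sketch (assuming a running `spark` session; `dep_udf` and `const_udf` are hypothetical names introduced for illustration). A UDF whose arguments reference the aggregate's output was already extracted by the rule; a zero-parameter UDF, whose reference set is empty, was not:
```python
import pyspark.sql.functions as f

df = spark.createDataFrame([(1, 2), (3, 4)])

# This UDF references a grouping key of the distinct (an Aggregate under the
# hood), so the old rule already extracted it above the aggregate:
dep_udf = f.udf(lambda x: str(x))
df.distinct().withColumn("a", dep_udf(df["_1"])).show()  # worked before the fix

# A zero-parameter UDF has no attribute references, so the old rule left it
# inside the Aggregate, where HashAggregateExec could not bind it:
const_udf = f.udf(lambda: "const_str")
df.distinct().withColumn("a", const_udf()).show()  # failed before the fix
```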

## How was this patch tested?

Added a test (`test_nonparam_udf_with_aggregate`) to `python/pyspark/sql/tests.py`.

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #20379 from viirya/SPARK-23177-backport-2.3.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/84a189a3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/84a189a3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/84a189a3

Branch: refs/heads/branch-2.3
Commit: 84a189a3429c64eeabc7b4e8fc0488ec16742002
Parents: d656be7
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Wed Jan 24 20:29:47 2018 +0900
Committer: hyukjinkwon <gurwls...@gmail.com>
Committed: Wed Jan 24 20:29:47 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                                  | 8 ++++++++
 .../spark/sql/execution/python/ExtractPythonUDFs.scala       | 5 +++--
 2 files changed, 11 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/84a189a3/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4fee2ec..fbb18c4 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1100,6 +1100,14 @@ class SQLTests(ReusedSQLTestCase):
         rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
         self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
 
+    def test_nonparam_udf_with_aggregate(self):
+        import pyspark.sql.functions as f
+
+        df = self.spark.createDataFrame([(1, 2), (1, 2)])
+        f_udf = f.udf(lambda: "const_str")
+        rows = df.distinct().withColumn("a", f_udf()).collect()
+        self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')])
+
     def test_infer_schema_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))

http://git-wip-us.apache.org/repos/asf/spark/blob/84a189a3/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 2f53fe7..e6616d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 
 /**
  * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or
- * grouping key, evaluate them after aggregate.
+ * grouping key, or doesn't depend on any above expressions, evaluate them after aggregate.
  */
 object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
 
@@ -44,7 +44,8 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
 
   private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
     expr.find {
-      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
+      e => e.isInstanceOf[PythonUDF] &&
+        (e.references.isEmpty || e.find(belongAggregate(_, agg)).isDefined)
     }.isDefined
   }
 

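The fix keys off `e.references.isEmpty`: a zero-parameter UDF references no attributes, so it can always be evaluated in a projection above the aggregate. One way to see the effect is to inspect the physical plan of the reproduction; with the fix, the UDF should appear in a `BatchEvalPython` node on top of the `HashAggregate` rather than inside it. A rough sketch (the exact plan text varies by Spark version):
```python
df2 = df.distinct().withColumn("a", f_udf())
df2.explain()
# Expected shape, roughly:
# == Physical Plan ==
# Project [_1#0L, _2#1L, pythonUDF0#... AS a#...]
# +- BatchEvalPython [<lambda>()], [_1#0L, _2#1L, pythonUDF0#...]
#    +- HashAggregate(keys=[_1#0L, _2#1L], functions=[])
#       +- Exchange hashpartitioning(_1#0L, _2#1L, ...)
#          +- HashAggregate(keys=[_1#0L, _2#1L], functions=[])
#             +- Scan ...
```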
