This is an automated email from the ASF dual-hosted git repository.

cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7858e53  [SPARK-28323][SQL][PYTHON] PythonUDF should be able to use in join condition
7858e53 is described below

commit 7858e534d3195d532874a3d90121353895ba3f42
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Wed Jul 10 16:29:58 2019 -0700

    [SPARK-28323][SQL][PYTHON] PythonUDF should be able to use in join condition
    
    ## What changes were proposed in this pull request?
    
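    There is a bug in `ExtractPythonUDFs` that produces wrong result attributes: it creates one result attribute for every extracted UDF (`udfs`) instead of only for the UDFs that are valid to evaluate at that point (`validUdfs`). This causes a failure when `PythonUDF`s are spread across multiple child plans, e.g., the two sides of a join. An example is using `PythonUDF`s in a join condition: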
    
    ```python
    >>> left = spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
    >>> right = spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
    >>> f = udf(lambda a: a, IntegerType())
    >>> df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
    >>> df.collect()
    19/07/10 12:20:49 ERROR Executor: Exception in task 5.0 in stage 0.0 (TID 5)
    java.lang.ArrayIndexOutOfBoundsException: 1
            at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.genericGet(rows.scala:201)
            at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.getAs(rows.scala:35)
            at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.isNullAt(rows.scala:36)
            at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.isNullAt$(rows.scala:36)
            at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.isNullAt(rows.scala:195)
            at org.apache.spark.sql.catalyst.expressions.JoinedRow.isNullAt(JoinedRow.scala:70)
            ...
    ```
    
    ## How was this patch tested?
    
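    Added tests: `test_udf_as_join_condition` in `python/pyspark/sql/tests/test_udf.py` and a SPARK-28323 test in `JoinSuite.scala`.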
    
    Closes #25091 from viirya/SPARK-28323.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Bryan Cutler <cutl...@gmail.com>
---
 python/pyspark/sql/tests/test_udf.py               | 10 +++++++++
 .../sql/execution/python/ExtractPythonUDFs.scala   |  2 +-
 .../scala/org/apache/spark/sql/JoinSuite.scala     | 25 ++++++++++++++++++++++
 3 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/test_udf.py 
b/python/pyspark/sql/tests/test_udf.py
index 0dafa18..803d471 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -197,6 +197,8 @@ class UDFTests(ReusedSQLTestCase):
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(b=1)])
         f = udf(lambda a, b: a == b, BooleanType())
+        # The udf uses attributes from both sides of join, so it is pulled out 
as Filter +
+        # Cross join.
         df = left.join(right, f("a", "b"))
         with self.assertRaisesRegexp(AnalysisException, 'Detected implicit 
cartesian product'):
             df.collect()
@@ -243,6 +245,14 @@ class UDFTests(ReusedSQLTestCase):
         runWithJoinType("leftanti", "LeftAnti")
         runWithJoinType("leftsemi", "LeftSemi")
 
+    def test_udf_as_join_condition(self):
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, 
a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, 
b1=3, b2=1)])
+        f = udf(lambda a: a, IntegerType())
+
+        df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
+        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
         [row] = self.spark.sql("SELECT foo()").collect()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 58fe7d5..fc4ded3 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -179,7 +179,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with 
PredicateHelper {
             validUdfs.forall(PythonUDF.isScalarPythonUDF),
             "Can only extract scalar vectorized udf or sql batch udf")
 
-          val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
+          val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) =>
             AttributeReference(s"pythonUDF$i", u.dataType)()
           }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 38c634e..32cddc9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -28,6 +28,7 @@ import 
org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
 import org.apache.spark.sql.execution.{BinaryExecNode, SortExec}
 import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.execution.python.BatchEvalPythonExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.StructType
@@ -969,4 +970,28 @@ class JoinSuite extends QueryTest with SharedSQLContext {
           Seq(Row(0.0d, 0.0/0.0)))))
     }
   }
+
+  test("SPARK-28323: PythonUDF should be able to use in join condition") {
+    import IntegratedUDFTestUtils._
+
+    assume(shouldTestPythonUDFs)
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+
+    val left = Seq((1, 2), (2, 3)).toDF("a", "b")
+    val right = Seq((1, 2), (3, 4)).toDF("c", "d")
+    val df = left.join(right, pythonTestUDF($"a") === pythonTestUDF($"c"))
+
+    val joinNode = 
df.queryExecution.executedPlan.find(_.isInstanceOf[BroadcastHashJoinExec])
+    assert(joinNode.isDefined)
+
+    // There are two PythonUDFs which use attribute from left and right of 
join, individually.
+    // So two PythonUDFs should be evaluated before the join operator, at left 
and right side.
+    val pythonEvals = joinNode.get.collect {
+      case p: BatchEvalPythonExec => p
+    }
+    assert(pythonEvals.size == 2)
+
+    checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to