This is an automated email from the ASF dual-hosted git repository. cutlerb pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7858e53 [SPARK-28323][SQL][PYTHON] PythonUDF should be able to use in join condition 7858e53 is described below commit 7858e534d3195d532874a3d90121353895ba3f42 Author: Liang-Chi Hsieh <vii...@gmail.com> AuthorDate: Wed Jul 10 16:29:58 2019 -0700 [SPARK-28323][SQL][PYTHON] PythonUDF should be able to use in join condition ## What changes were proposed in this pull request? There is a bug in `ExtractPythonUDFs` that produces wrong result attributes. It causes a failure when using `PythonUDF`s among multiple child plans, e.g., join. An example is using `PythonUDF`s in a join condition. ```python >>> left = spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) >>> right = spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) >>> f = udf(lambda a: a, IntegerType()) >>> df = left.join(right, [f("a") == f("b"), left.a1 == right.b1]) >>> df.collect() 19/07/10 12:20:49 ERROR Executor: Exception in task 5.0 in stage 0.0 (TID 5) java.lang.ArrayIndexOutOfBoundsException: 1 at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.genericGet(rows.scala:201) at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.getAs(rows.scala:35) at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.isNullAt(rows.scala:36) at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.isNullAt$(rows.scala:36) at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.isNullAt(rows.scala:195) at org.apache.spark.sql.catalyst.expressions.JoinedRow.isNullAt(JoinedRow.scala:70) ... ``` ## How was this patch tested? Added test. Closes #25091 from viirya/SPARK-28323. 
Authored-by: Liang-Chi Hsieh <vii...@gmail.com> Signed-off-by: Bryan Cutler <cutl...@gmail.com> --- python/pyspark/sql/tests/test_udf.py | 10 +++++++++ .../sql/execution/python/ExtractPythonUDFs.scala | 2 +- .../scala/org/apache/spark/sql/JoinSuite.scala | 25 ++++++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 0dafa18..803d471 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -197,6 +197,8 @@ class UDFTests(ReusedSQLTestCase): left = self.spark.createDataFrame([Row(a=1)]) right = self.spark.createDataFrame([Row(b=1)]) f = udf(lambda a, b: a == b, BooleanType()) + # The udf uses attributes from both sides of join, so it is pulled out as Filter + + # Cross join. df = left.join(right, f("a", "b")) with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): df.collect() @@ -243,6 +245,14 @@ class UDFTests(ReusedSQLTestCase): runWithJoinType("leftanti", "LeftAnti") runWithJoinType("leftsemi", "LeftSemi") + def test_udf_as_join_condition(self): + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) + f = udf(lambda a: a, IntegerType()) + + df = left.join(right, [f("a") == f("b"), left.a1 == right.b1]) + self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)]) + def test_udf_without_arguments(self): self.spark.catalog.registerFunction("foo", lambda: "bar") [row] = self.spark.sql("SELECT foo()").collect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 58fe7d5..fc4ded3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -179,7 +179,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper { validUdfs.forall(PythonUDF.isScalarPythonUDF), "Can only extract scalar vectorized udf or sql batch udf") - val resultAttrs = udfs.zipWithIndex.map { case (u, i) => + val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 38c634e..32cddc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} import org.apache.spark.sql.execution.{BinaryExecNode, SortExec} import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.python.BatchEvalPythonExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -969,4 +970,28 @@ class JoinSuite extends QueryTest with SharedSQLContext { Seq(Row(0.0d, 0.0/0.0))))) } } + + test("SPARK-28323: PythonUDF should be able to use in join condition") { + import IntegratedUDFTestUtils._ + + assume(shouldTestPythonUDFs) + + val pythonTestUDF = TestPythonUDF(name = "udf") + + val left = Seq((1, 2), (2, 3)).toDF("a", "b") + val right = Seq((1, 2), (3, 4)).toDF("c", "d") + val df = left.join(right, pythonTestUDF($"a") === pythonTestUDF($"c")) + + val joinNode = df.queryExecution.executedPlan.find(_.isInstanceOf[BroadcastHashJoinExec]) + assert(joinNode.isDefined) + + // There are two PythonUDFs which use attribute from left and right of join, individually. 
+ // So two PythonUDFs should be evaluated before the join operator, at left and right side. + val pythonEvals = joinNode.get.collect { + case p: BatchEvalPythonExec => p + } + assert(pythonEvals.size == 2) + + checkAnswer(df, Row(1, 2, 1, 2) :: Nil) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org