Repository: spark Updated Branches: refs/heads/master d03e0af80 -> 2a8cbfddb
[SPARK-25314][SQL] Fix Python UDF accessing attributes from both side of join in join conditions ## What changes were proposed in this pull request? Thanks to bahchis for reporting this. This is a follow-up to #16581: this PR fixes the scenario where a Python UDF in a join condition accesses attributes from both sides of the join. ## How was this patch tested? Added regression tests in PySpark and `BatchEvalPythonExecSuite`. Closes #22326 from xuanyuanking/SPARK-25314. Authored-by: Yuanjian Li <xyliyuanj...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2a8cbfdd Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2a8cbfdd Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2a8cbfdd Branch: refs/heads/master Commit: 2a8cbfddba2a59d144b32910c68c22d0199093fe Parents: d03e0af Author: Yuanjian Li <xyliyuanj...@gmail.com> Authored: Thu Sep 27 15:13:18 2018 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Thu Sep 27 15:13:18 2018 +0800 ---------------------------------------------------------------------- python/pyspark/sql/tests.py | 64 ++++++++++++++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 8 ++- .../spark/sql/catalyst/optimizer/joins.scala | 49 +++++++++++++++ 3 files changed, 119 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2a8cbfdd/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 64a7ceb..b88a655 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -596,6 +596,70 @@ class SQLTests(ReusedSQLTestCase): df = left.crossJoin(right).filter(f("a", "b")) self.assertEqual(df.collect(), [Row(a=1, b=1)]) + def test_udf_in_join_condition(self): + # regression test 
for SPARK-25314 + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(b=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, f("a", "b")) + with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): + df.collect() + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + self.assertEqual(df.collect(), [Row(a=1, b=1)]) + + def test_udf_in_left_semi_join_condition(self): + # regression test for SPARK-25314 + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, f("a", "b"), "leftsemi") + with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): + df.collect() + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) + + def test_udf_and_common_filter_in_join_condition(self): + # regression test for SPARK-25314 + # test the complex scenario with both udf and common filter + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, [f("a", "b"), left.a1 == right.b1]) + # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. 
+ self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)]) + + def test_udf_and_common_filter_in_left_semi_join_condition(self): + # regression test for SPARK-25314 + # test the complex scenario with both udf and common filter + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi") + # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. + self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) + + def test_udf_not_supported_in_join_condition(self): + # regression test for SPARK-25314 + # test python udf is not supported in join type besides left_semi and inner join. + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + + def runWithJoinType(join_type, type_string): + with self.assertRaisesRegexp( + AnalysisException, + 'Using PythonUDF.*%s is not supported.' 
% type_string): + left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect() + runWithJoinType("full", "FullOuter") + runWithJoinType("left", "LeftOuter") + runWithJoinType("right", "RightOuter") + runWithJoinType("leftanti", "LeftAnti") + def test_udf_without_arguments(self): self.spark.catalog.registerFunction("foo", lambda: "bar") [row] = self.spark.sql("SELECT foo()").collect() http://git-wip-us.apache.org/repos/asf/spark/blob/2a8cbfdd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 07a653f..da8009d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -165,7 +165,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, PropagateEmptyRelation) :+ - // The following batch should be executed after batch "Join Reorder" and "LocalRelation". + Batch("Extract PythonUDF From JoinCondition", Once, + PullOutPythonUDFInJoinCondition) :+ + // The following batch should be executed after batch "Join Reorder" "LocalRelation" and + // "Extract PythonUDF From JoinCondition". Batch("Check Cartesian Products", Once, CheckCartesianProducts) :+ Batch("RewriteSubquery", Once, @@ -202,7 +205,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) ReplaceDistinctWithAggregate.ruleName :: PullupCorrelatedPredicates.ruleName :: RewriteCorrelatedScalarSubquery.ruleName :: - RewritePredicateSubquery.ruleName :: Nil + RewritePredicateSubquery.ruleName :: + PullOutPythonUDFInJoinCondition.ruleName :: Nil /** * Optimize all the subqueries inside expression. 
http://git-wip-us.apache.org/repos/asf/spark/blob/2a8cbfdd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index edbeaf2..7149ede 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans._ @@ -152,3 +153,51 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) } } + +/** + * PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF + * and pull them out from join condition. For python udf accessing attributes from only one side, + * they are pushed down by operation push down rules. If not (e.g. user disables filter push + * down rules), we need to pull them out in this rule too. 
+ */ +object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper { + def hasPythonUDF(expression: Expression): Boolean = { + expression.collectFirst { case udf: PythonUDF => udf }.isDefined + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case j @ Join(_, _, joinType, condition) + if condition.isDefined && hasPythonUDF(condition.get) => + if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) { + // The current strategy only support InnerLike and LeftSemi join because for other type, + // it breaks SQL semantic if we run the join condition as a filter after join. If we pass + // the plan here, it'll still get a an invalid PythonUDF RuntimeException with message + // `requires attributes from more than one child`, we throw firstly here for better + // readable information. + throw new AnalysisException("Using PythonUDF in join condition of join type" + + s" $joinType is not supported.") + } + // If condition expression contains python udf, it will be moved out from + // the new join conditions. 
+ val (udf, rest) = + splitConjunctivePredicates(condition.get).partition(hasPythonUDF) + val newCondition = if (rest.isEmpty) { + logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," + + s" it will be moved out and the join plan will be turned to cross join.") + None + } else { + Some(rest.reduceLeft(And)) + } + val newJoin = j.copy(condition = newCondition) + joinType match { + case _: InnerLike => Filter(udf.reduceLeft(And), newJoin) + case LeftSemi => + Project( + j.left.output.map(_.toAttribute), + Filter(udf.reduceLeft(And), newJoin.copy(joinType = Inner))) + case _ => + throw new AnalysisException("Using PythonUDF in join condition of join type" + + s" $joinType is not supported.") + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org