This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 0fe361e3a1b [SPARK-42115][SQL] Push down limit through Python UDFs 0fe361e3a1b is described below commit 0fe361e3a1b5be04114402b78e62dd010703477b Author: Kelvin Jiang <kelvin.ji...@databricks.com> AuthorDate: Thu Feb 2 09:00:51 2023 +0800 [SPARK-42115][SQL] Push down limit through Python UDFs ### What changes were proposed in this pull request? This PR adds cases in LimitPushDown to push limits through Python UDFs. In order to allow for this, we need to call LimitPushDown in SparkOptimizer after the "Extract Python UDFs" batch. We also add PushProjectionThroughLimit afterwards in order to plan CollectLimit. ### Why are the changes needed? Right now, LimitPushDown does not push limits through Python UDFs, which means that expensive Python UDFs can be run on potentially large amounts of input. This PR adds this capability, while ensuring that a GlobalLimit - LocalLimit pattern stays at the top in order to trigger the CollectLimit code path. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added a UT. Closes #39842 from kelvinjian-db/SPARK-42115-limit-through-python-udfs. 
Authored-by: Kelvin Jiang <kelvin.ji...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/tests/pandas/test_pandas_udf_scalar.py | 1 + .../spark/sql/catalyst/optimizer/Optimizer.scala | 9 +++++ .../plans/logical/pythonLogicalOperators.scala | 3 ++ .../spark/sql/catalyst/trees/TreePatterns.scala | 1 + .../spark/sql/execution/SparkOptimizer.scala | 2 + .../execution/python/ExtractPythonUDFsSuite.scala | 43 ++++++++++++++++++++++ 6 files changed, 59 insertions(+) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py index cbb26e45d2f..33c957fac58 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py @@ -981,6 +981,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): with self.assertRaisesRegex(Exception, "reached finally block"): self.spark.range(1).select(test_close(col("id"))).collect() + @unittest.skip("LimitPushDown should push limits through Python UDFs so this won't occur") def test_scalar_iter_udf_close_early(self): tmp_dir = tempfile.mkdtemp() try: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 627e3952480..1233f2207f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -766,6 +766,15 @@ object LimitPushDown extends Rule[LogicalPlan] { // Push down local limit 1 if join type is LeftSemiOrAnti and join condition is empty. case j @ Join(_, right, LeftSemiOrAnti(_), None, _) if !right.maxRows.exists(_ <= 1) => j.copy(right = maybePushLocalLimit(Literal(1, IntegerType), right)) + // Push down limits through Python UDFs. 
+ case LocalLimit(le, udf: BatchEvalPython) => + LocalLimit(le, udf.copy(child = maybePushLocalLimit(le, udf.child))) + case LocalLimit(le, p @ Project(_, udf: BatchEvalPython)) => + LocalLimit(le, p.copy(child = udf.copy(child = maybePushLocalLimit(le, udf.child)))) + case LocalLimit(le, udf: ArrowEvalPython) => + LocalLimit(le, udf.copy(child = maybePushLocalLimit(le, udf.child))) + case LocalLimit(le, p @ Project(_, udf: ArrowEvalPython)) => + LocalLimit(le, p.copy(child = udf.copy(child = maybePushLocalLimit(le, udf.child)))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index e97ff7808f1..1ce6808be60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF} +import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.sql.types.StructType @@ -141,6 +142,8 @@ trait BaseEvalPython extends UnaryNode { override def output: Seq[Attribute] = child.output ++ resultAttrs override def producedAttributes: AttributeSet = AttributeSet(resultAttrs) + + final override val nodePatterns: Seq[TreePattern] = Seq(EVAL_PYTHON_UDF) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 3957ad4af2d..48db1a4408d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala 
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -103,6 +103,7 @@ object TreePattern extends Enumeration { val COMMAND: Value = Value val CTE: Value = Value val DISTINCT_LIKE: Value = Value + val EVAL_PYTHON_UDF: Value = Value val EVENT_TIME_WATERMARK: Value = Value val EXCEPT: Value = Value val FILTER: Value = Value diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index c61fd9ce10f..06e3888e7de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -79,7 +79,9 @@ class SparkOptimizer( // The eval-python node may be between Project/Filter and the scan node, which breaks // column pruning and filter push-down. Here we rerun the related optimizer rules. ColumnPruning, + LimitPushDown, PushPredicateThroughNonJoin, + PushProjectionThroughLimit, RemoveNoopOperators) :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) :+ Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala index 8519357dab0..0ab8691801d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.python +import org.apache.spark.sql.catalyst.plans.logical.{ArrowEvalPython, BatchEvalPython, Limit, LocalLimit} import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import 
org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan @@ -193,5 +194,47 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession { assert(pythonEvalNodes5.size == 1) assert(pythonEvalNodes5.head.udfs.size == 2) } + + test("Infers LocalLimit for Python evaluator") { + val df = Seq(("Hello", 4), ("World", 8)).toDF("a", "b") + + // Check that PushProjectionThroughLimit brings GlobalLimit - LocalLimit to the top (for + // CollectLimit) and that LimitPushDown keeps LocalLimit under UDF. + val df2 = df.limit(1).select(batchedPythonUDF(col("b"))) + assert(df2.queryExecution.optimizedPlan match { + case Limit(_, _) => true + }) + assert(df2.queryExecution.optimizedPlan.find { + case b: BatchEvalPython => b.child.isInstanceOf[LocalLimit] + case _ => false + }.isDefined) + + val df3 = df.limit(1).select(scalarPandasUDF(col("b"))) + assert(df3.queryExecution.optimizedPlan match { + case Limit(_, _) => true + }) + assert(df3.queryExecution.optimizedPlan.find { + case a: ArrowEvalPython => a.child.isInstanceOf[LocalLimit] + case _ => false + }.isDefined) + + val df4 = df.limit(1).select(batchedPythonUDF(col("b")), scalarPandasUDF(col("b"))) + assert(df4.queryExecution.optimizedPlan match { + case Limit(_, _) => true + }) + val evalsWithLimit = df4.queryExecution.optimizedPlan.collect { + case b: BatchEvalPython if b.child.isInstanceOf[LocalLimit] => b + case a: ArrowEvalPython if a.child.isInstanceOf[LocalLimit] => a + } + assert(evalsWithLimit.length == 2) + + // Check that LimitPushDown properly pushes LocalLimit past EvalPython operators. 
+ val df5 = df.select(batchedPythonUDF(col("b")), scalarPandasUDF(col("b"))).limit(1) + df5.queryExecution.optimizedPlan.foreach { + case b: BatchEvalPython => assert(b.child.isInstanceOf[LocalLimit]) + case a: ArrowEvalPython => assert(a.child.isInstanceOf[LocalLimit]) + case _ => + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org